abc
commited on
Commit
·
94f2ce5
1
Parent(s):
07048a3
Upload 54 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/workflows/typos.yml +21 -0
- .gitignore +7 -0
- LICENSE.md +201 -0
- _typos.toml +15 -0
- adastand.py +291 -0
- append_module.py +504 -0
- build/lib/library/__init__.py +0 -0
- build/lib/library/model_util.py +1180 -0
- build/lib/library/train_util.py +1796 -0
- fine_tune.py +360 -0
- gen_img_diffusers.py +0 -0
- library.egg-info/PKG-INFO +4 -0
- library.egg-info/SOURCES.txt +10 -0
- library.egg-info/dependency_links.txt +1 -0
- library.egg-info/top_level.txt +1 -0
- library/__init__.py +0 -0
- library/__pycache__/__init__.cpython-310.pyc +0 -0
- library/__pycache__/model_util.cpython-310.pyc +0 -0
- library/__pycache__/train_util.cpython-310.pyc +0 -0
- library/model_util.py +1180 -0
- library/train_util.py +1796 -0
- locon/__init__.py +0 -0
- locon/kohya_model_utils.py +1184 -0
- locon/kohya_utils.py +48 -0
- locon/locon.py +53 -0
- locon/locon_kohya.py +243 -0
- locon/utils.py +148 -0
- lora_train_popup.py +862 -0
- lycoris/__init__.py +8 -0
- lycoris/kohya.py +276 -0
- lycoris/kohya_model_utils.py +1184 -0
- lycoris/kohya_utils.py +48 -0
- lycoris/locon.py +76 -0
- lycoris/loha.py +116 -0
- lycoris/utils.py +271 -0
- networks/__pycache__/lora.cpython-310.pyc +0 -0
- networks/check_lora_weights.py +32 -0
- networks/extract_lora_from_models.py +164 -0
- networks/lora.py +237 -0
- networks/lora_interrogator.py +122 -0
- networks/merge_lora.py +212 -0
- networks/merge_lora_old.py +179 -0
- networks/resize_lora.py +198 -0
- networks/svd_merge_lora.py +164 -0
- requirements.txt +25 -0
- requirements_startup.txt +23 -0
- setup.py +3 -0
- tools/convert_diffusers20_original_sd.py +89 -0
- tools/detect_face_rotate.py +239 -0
- tools/resize_images_to_resolution.py +122 -0
.github/workflows/typos.yml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# yamllint disable rule:line-length
|
3 |
+
name: Typos
|
4 |
+
|
5 |
+
on: # yamllint disable-line rule:truthy
|
6 |
+
push:
|
7 |
+
pull_request:
|
8 |
+
types:
|
9 |
+
- opened
|
10 |
+
- synchronize
|
11 |
+
- reopened
|
12 |
+
|
13 |
+
jobs:
|
14 |
+
build:
|
15 |
+
runs-on: ubuntu-latest
|
16 |
+
|
17 |
+
steps:
|
18 |
+
- uses: actions/checkout@v3
|
19 |
+
|
20 |
+
- name: typos-action
|
21 |
+
uses: crate-ci/[email protected]
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
logs
|
2 |
+
__pycache__
|
3 |
+
wd14_tagger_model
|
4 |
+
venv
|
5 |
+
*.egg-info
|
6 |
+
build
|
7 |
+
.vscode
|
LICENSE.md
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [2022] [kohya-ss]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
_typos.toml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Files for typos
|
2 |
+
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
3 |
+
|
4 |
+
[default.extend-identifiers]
|
5 |
+
|
6 |
+
[default.extend-words]
|
7 |
+
NIN="NIN"
|
8 |
+
parms="parms"
|
9 |
+
nin="nin"
|
10 |
+
extention="extention" # Intentionally left
|
11 |
+
nd="nd"
|
12 |
+
|
13 |
+
|
14 |
+
[files]
|
15 |
+
extend-exclude = ["_typos.toml"]
|
adastand.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
def __version__():
|
4 |
+
return 0.5
|
5 |
+
#######################################################################################
|
6 |
+
#NTT提案のAdastand Optimizer NTTを信用できるならAdamより少し性能高い(2019)
|
7 |
+
#参考コード:https://github.com/bunag-public/adastand_pack/
|
8 |
+
#似た計算式のAdaBeliefがAdamWと同じweight_decayの計算を導入していたのでAdamWのweight_decay式を使えるように
|
9 |
+
class Adastand(torch.optim.Optimizer):
|
10 |
+
"""Implements Adastand algorithm.
|
11 |
+
Arguments:
|
12 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
13 |
+
parameter groups
|
14 |
+
lr (float, optional): learning rate (default: 1e-3)
|
15 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
16 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
17 |
+
eps (float, optional): term added to the denominator to improve
|
18 |
+
numerical stability (default: 1e-8)
|
19 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
20 |
+
weight_decouple (bool, optional): if True is weight decay as in AdamW
|
21 |
+
"""
|
22 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
23 |
+
weight_decay=0, weight_decouple=False, fixed_decay=False, amsgrad=False):
|
24 |
+
if not 0.0 <= lr:
|
25 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
26 |
+
if not 0.0 <= eps:
|
27 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
28 |
+
if not 0.0 <= betas[0] < 1.0:
|
29 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
30 |
+
if not 0.0 <= betas[1] < 1.0:
|
31 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
32 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
33 |
+
weight_decay=weight_decay, weight_decouple=weight_decouple, fixed_decay=fixed_decay, amsgrad=amsgrad)
|
34 |
+
super(Adastand, self).__init__(params, defaults)
|
35 |
+
|
36 |
+
def __setstate__(self, state):
|
37 |
+
super(Adastand, self).__setstate__(state)
|
38 |
+
|
39 |
+
def step(self, closure=None):
|
40 |
+
"""Performs a single optimization step.
|
41 |
+
Arguments:
|
42 |
+
closure (callable, optional): A closure that reevaluates the model
|
43 |
+
and returns the loss.
|
44 |
+
"""
|
45 |
+
loss = None
|
46 |
+
if closure is not None:
|
47 |
+
loss = closure()
|
48 |
+
|
49 |
+
for group in self.param_groups:
|
50 |
+
for p in group['params']:
|
51 |
+
if p.grad is None:
|
52 |
+
continue
|
53 |
+
grad = p.grad.data
|
54 |
+
if grad.is_sparse:
|
55 |
+
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
56 |
+
weight_decouple = group['weight_decouple']
|
57 |
+
fixed_decay = group['fixed_decay']
|
58 |
+
amsgrad = group['amsgrad']
|
59 |
+
|
60 |
+
state = self.state[p]
|
61 |
+
|
62 |
+
# State initialization
|
63 |
+
if len(state) == 0:
|
64 |
+
state['step'] = 0
|
65 |
+
# Exponential moving average of gradient values
|
66 |
+
state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
|
67 |
+
# Exponential moving average of squared gradient values
|
68 |
+
state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
|
69 |
+
if amsgrad:
|
70 |
+
# Maintains max of all exp. moving avg. of
|
71 |
+
# sq. grad. values
|
72 |
+
state['exp_avg_sqs'] = torch.zeros_like(
|
73 |
+
p.data, memory_format=torch.preserve_format
|
74 |
+
)
|
75 |
+
|
76 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
77 |
+
beta1, beta2 = group['betas']
|
78 |
+
|
79 |
+
state['step'] += 1
|
80 |
+
|
81 |
+
if weight_decouple:
|
82 |
+
if not fixed_decay:
|
83 |
+
p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
|
84 |
+
else:
|
85 |
+
p.data.mul_(1.0 - group['weight_decay'])
|
86 |
+
else:
|
87 |
+
if group['weight_decay'] != 0:
|
88 |
+
grad.add_(p.data, alpha=group['weight_decay'])
|
89 |
+
|
90 |
+
# Decay the first and second moment running average coefficient
|
91 |
+
grad_residual = grad - exp_avg
|
92 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad_residual, grad_residual, value=beta2 * (1 - beta2))
|
93 |
+
exp_avg.mul_(2 * beta1 - 1).add_(grad, alpha=1 - beta1)
|
94 |
+
|
95 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
96 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
97 |
+
if amsgrad:
|
98 |
+
exp_avg_sqs = state['exp_avg_sqs']
|
99 |
+
torch.max(exp_avg_sqs, exp_avg_sq, out=exp_avg_sqs)
|
100 |
+
denom = exp_avg_sqs.sqrt().add_(group['eps']/ math.sqrt(bias_correction2))
|
101 |
+
else:
|
102 |
+
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
103 |
+
|
104 |
+
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
105 |
+
|
106 |
+
p.data.addcdiv_(exp_avg, denom, value=-step_size)
|
107 |
+
#p.data.addcdiv_(-step_size, exp_avg, denom)
|
108 |
+
|
109 |
+
return loss
|
110 |
+
#######################################################################################
|
111 |
+
class Adastand_b(torch.optim.Optimizer):
|
112 |
+
"""Implements Adastand algorithm.
|
113 |
+
Arguments:
|
114 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
115 |
+
parameter groups
|
116 |
+
lr (float, optional): learning rate (default: 1e-3)
|
117 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
118 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
119 |
+
eps (float, optional): term added to the denominator to improve
|
120 |
+
numerical stability (default: 1e-8)
|
121 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
122 |
+
weight_decouple (bool, optional): if True is weight decay as in AdamW
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
params,
|
128 |
+
lr: float = 1e-3,
|
129 |
+
betas = (0.9, 0.999),
|
130 |
+
eps: float = 1e-8,
|
131 |
+
weight_decay: float = 0,
|
132 |
+
amsgrad: bool = False,
|
133 |
+
weight_decouple: bool = False,
|
134 |
+
fixed_decay: bool = False,
|
135 |
+
rectify: bool = False,
|
136 |
+
) -> None:
|
137 |
+
if lr <= 0.0:
|
138 |
+
raise ValueError('Invalid learning rate: {}'.format(lr))
|
139 |
+
if eps < 0.0:
|
140 |
+
raise ValueError('Invalid epsilon value: {}'.format(eps))
|
141 |
+
if not 0.0 <= betas[0] < 1.0:
|
142 |
+
raise ValueError(
|
143 |
+
'Invalid beta parameter at index 0: {}'.format(betas[0])
|
144 |
+
)
|
145 |
+
if not 0.0 <= betas[1] < 1.0:
|
146 |
+
raise ValueError(
|
147 |
+
'Invalid beta parameter at index 1: {}'.format(betas[1])
|
148 |
+
)
|
149 |
+
if weight_decay < 0:
|
150 |
+
raise ValueError(
|
151 |
+
'Invalid weight_decay value: {}'.format(weight_decay)
|
152 |
+
)
|
153 |
+
defaults = dict(
|
154 |
+
lr=lr,
|
155 |
+
betas=betas,
|
156 |
+
eps=eps,
|
157 |
+
weight_decay=weight_decay,
|
158 |
+
amsgrad=amsgrad,
|
159 |
+
)
|
160 |
+
super(Adastand_b, self).__init__(params, defaults)
|
161 |
+
|
162 |
+
self._weight_decouple = weight_decouple
|
163 |
+
self._rectify = rectify
|
164 |
+
self._fixed_decay = fixed_decay
|
165 |
+
|
166 |
+
def __setstate__(self, state):
|
167 |
+
super(Adastand_b, self).__setstate__(state)
|
168 |
+
for group in self.param_groups:
|
169 |
+
group.setdefault('amsgrad', False)
|
170 |
+
|
171 |
+
def step(self, closure=None):
|
172 |
+
r"""Performs a single optimization step.
|
173 |
+
|
174 |
+
Arguments:
|
175 |
+
closure: A closure that reevaluates the model and returns the loss.
|
176 |
+
"""
|
177 |
+
loss = None
|
178 |
+
if closure is not None:
|
179 |
+
loss = closure()
|
180 |
+
|
181 |
+
for group in self.param_groups:
|
182 |
+
for p in group['params']:
|
183 |
+
if p.grad is None:
|
184 |
+
continue
|
185 |
+
grad = p.grad.data
|
186 |
+
if grad.is_sparse:
|
187 |
+
raise RuntimeError(
|
188 |
+
'AdaBelief does not support sparse gradients, '
|
189 |
+
'please consider SparseAdam instead'
|
190 |
+
)
|
191 |
+
amsgrad = group['amsgrad']
|
192 |
+
|
193 |
+
state = self.state[p]
|
194 |
+
|
195 |
+
beta1, beta2 = group['betas']
|
196 |
+
|
197 |
+
# State initialization
|
198 |
+
if len(state) == 0:
|
199 |
+
state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
|
200 |
+
state['step'] = 0
|
201 |
+
# Exponential moving average of gradient values
|
202 |
+
state['exp_avg'] = torch.zeros_like(
|
203 |
+
p.data, memory_format=torch.preserve_format
|
204 |
+
)
|
205 |
+
# Exponential moving average of squared gradient values
|
206 |
+
state['exp_avg_var'] = torch.zeros_like(
|
207 |
+
p.data, memory_format=torch.preserve_format
|
208 |
+
)
|
209 |
+
if amsgrad:
|
210 |
+
# Maintains max of all exp. moving avg. of
|
211 |
+
# sq. grad. values
|
212 |
+
state['max_exp_avg_var'] = torch.zeros_like(
|
213 |
+
p.data, memory_format=torch.preserve_format
|
214 |
+
)
|
215 |
+
|
216 |
+
# get current state variable
|
217 |
+
exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
|
218 |
+
|
219 |
+
state['step'] += 1
|
220 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
221 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
222 |
+
|
223 |
+
# perform weight decay, check if decoupled weight decay
|
224 |
+
if self._weight_decouple:
|
225 |
+
if not self._fixed_decay:
|
226 |
+
p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
|
227 |
+
else:
|
228 |
+
p.data.mul_(1.0 - group['weight_decay'])
|
229 |
+
else:
|
230 |
+
if group['weight_decay'] != 0:
|
231 |
+
grad.add_(p.data, alpha=group['weight_decay'])
|
232 |
+
|
233 |
+
# Update first and second moment running average
|
234 |
+
exp_avg.mul_(2*beta1-1).add_(grad, alpha=1 - beta1)
|
235 |
+
grad_residual = grad - exp_avg
|
236 |
+
exp_avg_var.mul_(beta2).addcmul_(
|
237 |
+
grad_residual, grad_residual, value=beta2*(1 - beta2)
|
238 |
+
)
|
239 |
+
|
240 |
+
if amsgrad:
|
241 |
+
max_exp_avg_var = state['max_exp_avg_var']
|
242 |
+
# Maintains the maximum of all 2nd moment running
|
243 |
+
# avg. till now
|
244 |
+
torch.max(
|
245 |
+
max_exp_avg_var, exp_avg_var, out=max_exp_avg_var
|
246 |
+
)
|
247 |
+
|
248 |
+
# Use the max. for normalizing running avg. of gradient
|
249 |
+
denom = (
|
250 |
+
max_exp_avg_var.add_(group['eps']).sqrt()
|
251 |
+
/ math.sqrt(bias_correction2)
|
252 |
+
).add_(group['eps'])
|
253 |
+
else:
|
254 |
+
denom = (
|
255 |
+
exp_avg_var.add_(group['eps']).sqrt()
|
256 |
+
/ math.sqrt(bias_correction2)
|
257 |
+
).add_(group['eps'])
|
258 |
+
|
259 |
+
if not self._rectify:
|
260 |
+
# Default update
|
261 |
+
step_size = group['lr']* math.sqrt(bias_correction2) / bias_correction1
|
262 |
+
p.data.addcdiv_(exp_avg, denom, value=-step_size)
|
263 |
+
|
264 |
+
else: # Rectified update
|
265 |
+
# calculate rho_t
|
266 |
+
state['rho_t'] = state['rho_inf'] - 2 * state[
|
267 |
+
'step'
|
268 |
+
] * beta2 ** state['step'] / (1.0 - beta2 ** state['step'])
|
269 |
+
|
270 |
+
if (
|
271 |
+
state['rho_t'] > 4
|
272 |
+
): # perform Adam style update if variance is small
|
273 |
+
rho_inf, rho_t = state['rho_inf'], state['rho_t']
|
274 |
+
rt = (
|
275 |
+
(rho_t - 4.0)
|
276 |
+
* (rho_t - 2.0)
|
277 |
+
* rho_inf
|
278 |
+
/ (rho_inf - 4.0)
|
279 |
+
/ (rho_inf - 2.0)
|
280 |
+
/ rho_t
|
281 |
+
)
|
282 |
+
rt = math.sqrt(rt)
|
283 |
+
|
284 |
+
step_size = rt * group['lr'] / bias_correction1
|
285 |
+
|
286 |
+
p.data.addcdiv_(-step_size, exp_avg, denom)
|
287 |
+
|
288 |
+
else: # perform SGD style update
|
289 |
+
p.data.add_(-group['lr'], exp_avg)
|
290 |
+
|
291 |
+
return loss
|
append_module.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import shutil
|
4 |
+
import time
|
5 |
+
from typing import Dict, List, NamedTuple, Tuple
|
6 |
+
from accelerate import Accelerator
|
7 |
+
from torch.autograd.function import Function
|
8 |
+
import glob
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import hashlib
|
13 |
+
from io import BytesIO
|
14 |
+
|
15 |
+
from tqdm import tqdm
|
16 |
+
import torch
|
17 |
+
from torchvision import transforms
|
18 |
+
from transformers import CLIPTokenizer
|
19 |
+
import diffusers
|
20 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
21 |
+
import albumentations as albu
|
22 |
+
import numpy as np
|
23 |
+
from PIL import Image
|
24 |
+
import cv2
|
25 |
+
from einops import rearrange
|
26 |
+
from torch import einsum
|
27 |
+
import safetensors.torch
|
28 |
+
|
29 |
+
import library.model_util as model_util
|
30 |
+
import library.train_util as train_util
|
31 |
+
|
32 |
+
#============================================================================================================
|
33 |
+
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
34 |
+
#============================================================================================================
|
35 |
+
from torch.optim.lr_scheduler import LambdaLR
|
36 |
+
class AdafactorSchedule_append(LambdaLR):
|
37 |
+
"""
|
38 |
+
Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
|
39 |
+
for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
|
40 |
+
|
41 |
+
It returns `initial_lr` during startup and the actual `lr` during stepping.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, optimizer, initial_lr=0.0):
|
45 |
+
def lr_lambda(_):
|
46 |
+
return initial_lr
|
47 |
+
|
48 |
+
for group in optimizer.param_groups:
|
49 |
+
if not type(initial_lr)==list:
|
50 |
+
group["initial_lr"] = initial_lr
|
51 |
+
else:
|
52 |
+
group["initial_lr"] = initial_lr.pop(0)
|
53 |
+
super().__init__(optimizer, lr_lambda)
|
54 |
+
for group in optimizer.param_groups:
|
55 |
+
del group["initial_lr"]
|
56 |
+
|
57 |
+
def get_lr(self):
|
58 |
+
opt = self.optimizer
|
59 |
+
lrs = [
|
60 |
+
opt._get_lr(group, opt.state[group["params"][0]])
|
61 |
+
for group in opt.param_groups
|
62 |
+
if group["params"][0].grad is not None
|
63 |
+
]
|
64 |
+
if len(lrs) == 0:
|
65 |
+
lrs = self.base_lrs # if called before stepping
|
66 |
+
return lrs
|
67 |
+
|
68 |
+
#============================================================================================================
|
69 |
+
#model_util 内より
|
70 |
+
#============================================================================================================
|
71 |
+
def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024, divisible=64, step=1):
|
72 |
+
max_width, max_height = max_reso
|
73 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
74 |
+
|
75 |
+
min_widht, min_height = min_reso
|
76 |
+
min_area = (min_widht // divisible) * (min_height // divisible)
|
77 |
+
|
78 |
+
area_size_list = []
|
79 |
+
area_size_resos_list = []
|
80 |
+
_max_area = max_area
|
81 |
+
|
82 |
+
while True:
|
83 |
+
resos = set()
|
84 |
+
size = int(math.sqrt(_max_area)) * divisible
|
85 |
+
resos.add((size, size))
|
86 |
+
|
87 |
+
size = min_size
|
88 |
+
while size <= max_size:
|
89 |
+
width = size
|
90 |
+
height = min(max_size, (_max_area // (width // divisible)) * divisible)
|
91 |
+
resos.add((width, height))
|
92 |
+
resos.add((height, width))
|
93 |
+
|
94 |
+
# # make additional resos
|
95 |
+
# if width >= height and width - divisible >= min_size:
|
96 |
+
# resos.add((width - divisible, height))
|
97 |
+
# resos.add((height, width - divisible))
|
98 |
+
# if height >= width and height - divisible >= min_size:
|
99 |
+
# resos.add((width, height - divisible))
|
100 |
+
# resos.add((height - divisible, width))
|
101 |
+
|
102 |
+
size += divisible
|
103 |
+
|
104 |
+
resos = list(resos)
|
105 |
+
resos.sort()
|
106 |
+
|
107 |
+
#aspect_ratios = [w / h for w, h in resos]
|
108 |
+
area_size_list.append(_max_area)
|
109 |
+
area_size_resos_list.append(resos)
|
110 |
+
#area_size_ratio_list.append(aspect_ratios)
|
111 |
+
|
112 |
+
_max_area -= step
|
113 |
+
if _max_area < min_area:
|
114 |
+
break
|
115 |
+
return area_size_resos_list, area_size_list
|
116 |
+
|
117 |
+
#============================================================================================================
|
118 |
+
#train_util 内より
|
119 |
+
#============================================================================================================
|
120 |
+
class BucketManager_append(train_util.BucketManager):
|
121 |
+
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps, min_reso=None, area_step=None) -> None:
|
122 |
+
super().__init__(no_upscale, max_reso, min_size, max_size, reso_steps)
|
123 |
+
print("BucketManager_appendを作成しました")
|
124 |
+
if min_reso is None:
|
125 |
+
self.min_reso = None
|
126 |
+
self.min_area = None
|
127 |
+
else:
|
128 |
+
self.min_reso = min_reso
|
129 |
+
self.min_area = min_reso[0] * min_reso[1]
|
130 |
+
self.area_step = area_step
|
131 |
+
self.area_sizes_flag = False
|
132 |
+
|
133 |
+
def make_buckets(self):
|
134 |
+
if self.min_reso:
|
135 |
+
print(f"make_resolution append")
|
136 |
+
resos, area_sizes = make_bucket_resolutions_fix(self.max_reso, self.min_reso, self.min_size, self.max_size, self.reso_steps, self.area_step)
|
137 |
+
self.set_predefined_resos(resos, area_sizes)
|
138 |
+
else:
|
139 |
+
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
140 |
+
self.set_predefined_resos(resos)
|
141 |
+
|
142 |
+
def set_predefined_resos(self, resos, area_sizes=None):
|
143 |
+
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
144 |
+
if area_sizes:
|
145 |
+
self.area_sizes_flag = True
|
146 |
+
self.predefined_area_sizes = np.array(area_sizes.copy())
|
147 |
+
self.predefined_resos_list = resos.copy()
|
148 |
+
self.predefined_resos_set_list = [set(reso) for reso in resos]
|
149 |
+
self.predefined_aspect_ratios_list = [np.array([w/h for w,h in reso]) for reso in resos]
|
150 |
+
self.predefined_resos = None
|
151 |
+
self.predefined_resos_set = None
|
152 |
+
self.predefined_aspect_ratios = None
|
153 |
+
else:
|
154 |
+
self.area_sizes_flag = False
|
155 |
+
self.predefined_area_sizes = None
|
156 |
+
self.predefined_resos = resos.copy()
|
157 |
+
self.predefined_resos_set = set(resos)
|
158 |
+
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
159 |
+
|
160 |
+
def select_bucket(self, image_width, image_height):
|
161 |
+
# 画像サイズを算出する
|
162 |
+
area_size = (image_width//64) * (image_height//64)
|
163 |
+
aspect_ratio = image_width / image_height
|
164 |
+
bucket_size_id = None
|
165 |
+
# 拡張したバケットサイズを使うために画像サイズのエリアを決定する
|
166 |
+
if self.area_sizes_flag:
|
167 |
+
size_errors = self.predefined_area_sizes - area_size
|
168 |
+
bucket_size_id = np.abs(size_errors).argmin()
|
169 |
+
#一定の範囲を探索して使用する画像サイズを確定する
|
170 |
+
serch_size_range = 1
|
171 |
+
bucket_size_id_list = [bucket_size_id]
|
172 |
+
for i in range(serch_size_range):
|
173 |
+
if bucket_size_id - i <0:
|
174 |
+
bucket_size_id_list.append(bucket_size_id + i + 1)
|
175 |
+
elif bucket_size_id + 1 + i >= len(self.predefined_resos_list):
|
176 |
+
bucket_size_id_list.append(bucket_size_id - i - 1)
|
177 |
+
else:
|
178 |
+
bucket_size_id_list.append(bucket_size_id - i - 1)
|
179 |
+
bucket_size_id_list.append(bucket_size_id + i + 1)
|
180 |
+
_min_error = 1000.
|
181 |
+
_min_id = bucket_size_id
|
182 |
+
for now_size_id in bucket_size_id:
|
183 |
+
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
184 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
185 |
+
ar_error = np.abs(ar_errors).min()
|
186 |
+
if _min_error > ar_error:
|
187 |
+
_min_error = ar_error
|
188 |
+
_min_id = now_size_id
|
189 |
+
if _min_error == 0.:
|
190 |
+
break
|
191 |
+
bucket_size_id = _min_id
|
192 |
+
del _min_error, _min_id, ar_error #余計なものは掃除
|
193 |
+
self.predefined_resos = self.predefined_resos_list[bucket_size_id]
|
194 |
+
self.predefined_resos_set = self.predefined_resos_set_list[bucket_size_id]
|
195 |
+
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[bucket_size_id]
|
196 |
+
# --ここから処理はそのまま
|
197 |
+
if not self.no_upscale:
|
198 |
+
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
199 |
+
reso = (image_width, image_height)
|
200 |
+
if reso in self.predefined_resos_set:
|
201 |
+
pass
|
202 |
+
else:
|
203 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
204 |
+
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
205 |
+
reso = self.predefined_resos[predefined_bucket_id]
|
206 |
+
|
207 |
+
ar_reso = reso[0] / reso[1]
|
208 |
+
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
209 |
+
scale = reso[1] / image_height
|
210 |
+
else:
|
211 |
+
scale = reso[0] / image_width
|
212 |
+
|
213 |
+
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
214 |
+
# print("use predef", image_width, image_height, reso, resized_size)
|
215 |
+
else:
|
216 |
+
if image_width * image_height > self.max_area:
|
217 |
+
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
218 |
+
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
219 |
+
resized_height = self.max_area / resized_width
|
220 |
+
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
221 |
+
|
222 |
+
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
223 |
+
# 元のbucketingと同じロジック
|
224 |
+
b_width_rounded = self.round_to_steps(resized_width)
|
225 |
+
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
226 |
+
ar_width_rounded = b_width_rounded / b_height_in_wr
|
227 |
+
|
228 |
+
b_height_rounded = self.round_to_steps(resized_height)
|
229 |
+
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
230 |
+
ar_height_rounded = b_width_in_hr / b_height_rounded
|
231 |
+
|
232 |
+
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
233 |
+
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
234 |
+
|
235 |
+
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
236 |
+
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
237 |
+
else:
|
238 |
+
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
239 |
+
# print(resized_size)
|
240 |
+
else:
|
241 |
+
resized_size = (image_width, image_height) # リサイズは不要
|
242 |
+
|
243 |
+
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
244 |
+
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
245 |
+
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
246 |
+
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
247 |
+
|
248 |
+
reso = (bucket_width, bucket_height)
|
249 |
+
|
250 |
+
self.add_if_new_reso(reso)
|
251 |
+
|
252 |
+
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
253 |
+
return reso, resized_size, ar_error
|
254 |
+
|
255 |
+
class DreamBoothDataset(train_util.DreamBoothDataset):
|
256 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
|
257 |
+
print("use append DreamBoothDataset")
|
258 |
+
self.min_resolution = min_resolution
|
259 |
+
self.area_step = area_step
|
260 |
+
super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
|
261 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
|
262 |
+
flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
263 |
+
def make_buckets(self):
|
264 |
+
'''
|
265 |
+
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
266 |
+
min_size and max_size are ignored when enable_bucket is False
|
267 |
+
'''
|
268 |
+
print("loading image sizes.")
|
269 |
+
for info in tqdm(self.image_data.values()):
|
270 |
+
if info.image_size is None:
|
271 |
+
info.image_size = self.get_image_size(info.absolute_path)
|
272 |
+
|
273 |
+
if self.enable_bucket:
|
274 |
+
print("make buckets")
|
275 |
+
else:
|
276 |
+
print("prepare dataset")
|
277 |
+
|
278 |
+
# bucketを作成し、画像をbucketに振り分ける
|
279 |
+
if self.enable_bucket:
|
280 |
+
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
281 |
+
#======================================================================change
|
282 |
+
if self.min_resolution:
|
283 |
+
self.bucket_manager = BucketManager_append(self.bucket_no_upscale, (self.width, self.height),
|
284 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps, self.min_resolution, self.area_step)
|
285 |
+
else:
|
286 |
+
self.bucket_manager = train_util.BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
287 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
288 |
+
#======================================================================change
|
289 |
+
if not self.bucket_no_upscale:
|
290 |
+
self.bucket_manager.make_buckets()
|
291 |
+
else:
|
292 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
293 |
+
|
294 |
+
img_ar_errors = []
|
295 |
+
for image_info in self.image_data.values():
|
296 |
+
image_width, image_height = image_info.image_size
|
297 |
+
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
298 |
+
|
299 |
+
# print(image_info.image_key, image_info.bucket_reso)
|
300 |
+
img_ar_errors.append(abs(ar_error))
|
301 |
+
|
302 |
+
self.bucket_manager.sort()
|
303 |
+
else:
|
304 |
+
self.bucket_manager = train_util.BucketManager(False, (self.width, self.height), None, None, None)
|
305 |
+
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
306 |
+
for image_info in self.image_data.values():
|
307 |
+
image_width, image_height = image_info.image_size
|
308 |
+
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
309 |
+
|
310 |
+
for image_info in self.image_data.values():
|
311 |
+
for _ in range(image_info.num_repeats):
|
312 |
+
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
313 |
+
|
314 |
+
# bucket情報を表示、格納する
|
315 |
+
if self.enable_bucket:
|
316 |
+
self.bucket_info = {"buckets": {}}
|
317 |
+
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
318 |
+
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
319 |
+
count = len(bucket)
|
320 |
+
if count > 0:
|
321 |
+
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
322 |
+
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
323 |
+
|
324 |
+
img_ar_errors = np.array(img_ar_errors)
|
325 |
+
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
326 |
+
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
327 |
+
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
328 |
+
|
329 |
+
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
330 |
+
self.buckets_indices: List(train_util.BucketBatchIndex) = []
|
331 |
+
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
332 |
+
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
333 |
+
for batch_index in range(batch_count):
|
334 |
+
self.buckets_indices.append(train_util.BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
335 |
+
|
336 |
+
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
337 |
+
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
338 |
+
#
|
339 |
+
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
340 |
+
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
341 |
+
# # そのためバッチサイズを画像種類までに制限する
|
342 |
+
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
343 |
+
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
344 |
+
# num_of_image_types = len(set(bucket))
|
345 |
+
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
346 |
+
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
347 |
+
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
348 |
+
# for batch_index in range(batch_count):
|
349 |
+
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
350 |
+
# ↑ここまで
|
351 |
+
|
352 |
+
self.shuffle_buckets()
|
353 |
+
self._length = len(self.buckets_indices)
|
354 |
+
|
355 |
+
class FineTuningDataset(train_util.FineTuningDataset):
|
356 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
357 |
+
train_util.glob_images = glob_images
|
358 |
+
super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
359 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
|
360 |
+
random_crop, dataset_repeats, debug_dataset)
|
361 |
+
|
362 |
+
def glob_images(directory, base="*", npz_flag=True):
|
363 |
+
img_paths = []
|
364 |
+
dots = []
|
365 |
+
for ext in train_util.IMAGE_EXTENSIONS:
|
366 |
+
dots.append(ext)
|
367 |
+
if npz_flag:
|
368 |
+
dots.append(".npz")
|
369 |
+
for ext in dots:
|
370 |
+
if base == '*':
|
371 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
372 |
+
else:
|
373 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
374 |
+
return img_paths
|
375 |
+
|
376 |
+
#============================================================================================================
|
377 |
+
#networks.lora
|
378 |
+
#============================================================================================================
|
379 |
+
from networks.lora import LoRANetwork
|
380 |
+
def replace_prepare_optimizer_params(networks):
|
381 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
|
382 |
+
|
383 |
+
def enumerate_params(loras, lora_name=None):
|
384 |
+
params = []
|
385 |
+
for lora in loras:
|
386 |
+
if lora_name is not None:
|
387 |
+
if lora_name in lora.lora_name:
|
388 |
+
params.extend(lora.parameters())
|
389 |
+
else:
|
390 |
+
params.extend(lora.parameters())
|
391 |
+
return params
|
392 |
+
|
393 |
+
self.requires_grad_(True)
|
394 |
+
all_params = []
|
395 |
+
ret_scheduler_lr = []
|
396 |
+
|
397 |
+
if loranames is not None:
|
398 |
+
textencoder_names = [None]
|
399 |
+
unet_names = [None]
|
400 |
+
if "text_encoder" in loranames:
|
401 |
+
textencoder_names = loranames["text_encoder"]
|
402 |
+
if "unet" in loranames:
|
403 |
+
unet_names = loranames["unet"]
|
404 |
+
|
405 |
+
if self.text_encoder_loras:
|
406 |
+
for textencoder_name in textencoder_names:
|
407 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
408 |
+
if text_encoder_lr is not None:
|
409 |
+
param_data['lr'] = text_encoder_lr
|
410 |
+
if scheduler_lr is not None:
|
411 |
+
ret_scheduler_lr.append(scheduler_lr[0])
|
412 |
+
all_params.append(param_data)
|
413 |
+
|
414 |
+
if self.unet_loras:
|
415 |
+
for unet_name in unet_names:
|
416 |
+
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
417 |
+
if unet_lr is not None:
|
418 |
+
param_data['lr'] = unet_lr
|
419 |
+
if scheduler_lr is not None:
|
420 |
+
ret_scheduler_lr.append(scheduler_lr[1])
|
421 |
+
all_params.append(param_data)
|
422 |
+
|
423 |
+
return all_params, ret_scheduler_lr
|
424 |
+
|
425 |
+
LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
|
426 |
+
|
427 |
+
#============================================================================================================
|
428 |
+
#新規追加
|
429 |
+
#============================================================================================================
|
430 |
+
def add_append_arguments(parser: argparse.ArgumentParser):
|
431 |
+
# for train_network_opt.py
|
432 |
+
parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
433 |
+
parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
434 |
+
parser.add_argument("--split_lora_networks", action="store_true")
|
435 |
+
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
436 |
+
parser.add_argument("--min_resolution", type=str, default=None)
|
437 |
+
parser.add_argument("--area_step", type=int, default=1)
|
438 |
+
parser.add_argument("--config", type=str, default=None)
|
439 |
+
|
440 |
+
def create_split_names(split_flag, split_level):
|
441 |
+
split_names = None
|
442 |
+
if split_flag:
|
443 |
+
split_names = {}
|
444 |
+
text_encoder_names = [None]
|
445 |
+
unet_names = ["lora_unet_mid_block"]
|
446 |
+
if split_level==1:
|
447 |
+
unet_names.append(f"lora_unet_down_blocks_")
|
448 |
+
unet_names.append(f"lora_unet_up_blocks_")
|
449 |
+
elif split_level==2 or split_level==0:
|
450 |
+
if split_level==2:
|
451 |
+
text_encoder_names = []
|
452 |
+
for i in range(12):
|
453 |
+
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
454 |
+
for i in range(3):
|
455 |
+
unet_names.append(f"lora_unet_down_blocks_{i}")
|
456 |
+
unet_names.append(f"lora_unet_up_blocks_{i+1}")
|
457 |
+
split_names["text_encoder"] = text_encoder_names
|
458 |
+
split_names["unet"] = unet_names
|
459 |
+
return split_names
|
460 |
+
|
461 |
+
def get_config(parser):
|
462 |
+
args = parser.parse_args()
|
463 |
+
if args.config is not None and (not args.config==""):
|
464 |
+
import yaml
|
465 |
+
import datetime
|
466 |
+
if os.path.splitext(args.config)[-1] == ".yaml":
|
467 |
+
args.config = os.path.splitext(args.config)[0]
|
468 |
+
config_path = f"./{args.config}.yaml"
|
469 |
+
if os.path.exists(config_path):
|
470 |
+
print(f"{config_path} から設定を読み込み中...")
|
471 |
+
margs, rest = parser.parse_known_args()
|
472 |
+
with open(config_path, mode="r") as f:
|
473 |
+
configs = yaml.unsafe_load(f)
|
474 |
+
#変数でのやり取りをするためargparserからDict型を取り出す
|
475 |
+
args_dic = vars(args)
|
476 |
+
#デフォから引数指定で変更があるものを確認
|
477 |
+
change_def_dic = {}
|
478 |
+
args_type_dic = {}
|
479 |
+
for key, v in args_dic.items():
|
480 |
+
if not parser.get_default(key) == v:
|
481 |
+
change_def_dic[key] = v
|
482 |
+
#デフォ指定されてるデータ型を取得する
|
483 |
+
for key, act in parser._option_string_actions.items():
|
484 |
+
if key=="-h": continue
|
485 |
+
key = key[2:]
|
486 |
+
args_type_dic[key] = act.type
|
487 |
+
#データタイプの確認とargsにkeyの内容を代入していく
|
488 |
+
for key, v in configs.items():
|
489 |
+
if key in args_dic:
|
490 |
+
if args_dic[key] is not None:
|
491 |
+
new_type = type(args_dic[key])
|
492 |
+
if (not type(v) == new_type) and (not new_type==list):
|
493 |
+
v = new_type(v)
|
494 |
+
else:
|
495 |
+
if v is not None:
|
496 |
+
if not type(v) == args_type_dic[key]:
|
497 |
+
v = args_type_dic[key](v)
|
498 |
+
args_dic[key] = v
|
499 |
+
#最後にデフォから指定が変わってるものを変更する
|
500 |
+
for key, v in change_def_dic.items():
|
501 |
+
args_dic[key] = v
|
502 |
+
else:
|
503 |
+
print(f"{config_path} が見つかりませんでした")
|
504 |
+
return args
|
build/lib/library/__init__.py
ADDED
File without changes
|
build/lib/library/model_util.py
ADDED
@@ -0,0 +1,1180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
+
from safetensors.torch import load_file, save_file
|
10 |
+
|
11 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
12 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
13 |
+
BETA_START = 0.00085
|
14 |
+
BETA_END = 0.0120
|
15 |
+
|
16 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
17 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
18 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
19 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
20 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
21 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
22 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
23 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
24 |
+
UNET_PARAMS_NUM_HEADS = 8
|
25 |
+
|
26 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
27 |
+
VAE_PARAMS_RESOLUTION = 256
|
28 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
29 |
+
VAE_PARAMS_OUT_CH = 3
|
30 |
+
VAE_PARAMS_CH = 128
|
31 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
32 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
|
34 |
+
# V2
|
35 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
36 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
37 |
+
|
38 |
+
# Diffusersの設定を読み込むための参照モデル
|
39 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
40 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
41 |
+
|
42 |
+
|
43 |
+
# region StableDiffusion->Diffusersの変換コード
|
44 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
45 |
+
|
46 |
+
|
47 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
48 |
+
"""
|
49 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
50 |
+
"""
|
51 |
+
if n_shave_prefix_segments >= 0:
|
52 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
53 |
+
else:
|
54 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
55 |
+
|
56 |
+
|
57 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
58 |
+
"""
|
59 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
60 |
+
"""
|
61 |
+
mapping = []
|
62 |
+
for old_item in old_list:
|
63 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
64 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
65 |
+
|
66 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
67 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
68 |
+
|
69 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
70 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
71 |
+
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
+
|
74 |
+
mapping.append({"old": old_item, "new": new_item})
|
75 |
+
|
76 |
+
return mapping
|
77 |
+
|
78 |
+
|
79 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
+
|
87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
89 |
+
|
90 |
+
mapping.append({"old": old_item, "new": new_item})
|
91 |
+
|
92 |
+
return mapping
|
93 |
+
|
94 |
+
|
95 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
96 |
+
"""
|
97 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
98 |
+
"""
|
99 |
+
mapping = []
|
100 |
+
for old_item in old_list:
|
101 |
+
new_item = old_item
|
102 |
+
|
103 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
104 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
105 |
+
|
106 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
107 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
108 |
+
|
109 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
110 |
+
|
111 |
+
mapping.append({"old": old_item, "new": new_item})
|
112 |
+
|
113 |
+
return mapping
|
114 |
+
|
115 |
+
|
116 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
117 |
+
"""
|
118 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
119 |
+
"""
|
120 |
+
mapping = []
|
121 |
+
for old_item in old_list:
|
122 |
+
new_item = old_item
|
123 |
+
|
124 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
125 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
126 |
+
|
127 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
128 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
131 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
134 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
137 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
138 |
+
|
139 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
140 |
+
|
141 |
+
mapping.append({"old": old_item, "new": new_item})
|
142 |
+
|
143 |
+
return mapping
|
144 |
+
|
145 |
+
|
146 |
+
def assign_to_checkpoint(
|
147 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
151 |
+
to them. It splits attention layers, and takes into account additional replacements
|
152 |
+
that may arise.
|
153 |
+
|
154 |
+
Assigns the weights to the new checkpoint.
|
155 |
+
"""
|
156 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
157 |
+
|
158 |
+
# Splits the attention layers into three variables.
|
159 |
+
if attention_paths_to_split is not None:
|
160 |
+
for path, path_map in attention_paths_to_split.items():
|
161 |
+
old_tensor = old_checkpoint[path]
|
162 |
+
channels = old_tensor.shape[0] // 3
|
163 |
+
|
164 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
165 |
+
|
166 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
167 |
+
|
168 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
169 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
170 |
+
|
171 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
172 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
173 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
174 |
+
|
175 |
+
for path in paths:
|
176 |
+
new_path = path["new"]
|
177 |
+
|
178 |
+
# These have already been assigned
|
179 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
180 |
+
continue
|
181 |
+
|
182 |
+
# Global renaming happens here
|
183 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
184 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
185 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
186 |
+
|
187 |
+
if additional_replacements is not None:
|
188 |
+
for replacement in additional_replacements:
|
189 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
190 |
+
|
191 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
192 |
+
if "proj_attn.weight" in new_path:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
194 |
+
else:
|
195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
196 |
+
|
197 |
+
|
198 |
+
def conv_attn_to_linear(checkpoint):
|
199 |
+
keys = list(checkpoint.keys())
|
200 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
201 |
+
for key in keys:
|
202 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
205 |
+
elif "proj_attn.weight" in key:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
208 |
+
|
209 |
+
|
210 |
+
def linear_transformer_to_conv(checkpoint):
|
211 |
+
keys = list(checkpoint.keys())
|
212 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
213 |
+
for key in keys:
|
214 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
215 |
+
if checkpoint[key].ndim == 2:
|
216 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
217 |
+
|
218 |
+
|
219 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
220 |
+
"""
|
221 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
222 |
+
"""
|
223 |
+
|
224 |
+
# extract state_dict for UNet
|
225 |
+
unet_state_dict = {}
|
226 |
+
unet_key = "model.diffusion_model."
|
227 |
+
keys = list(checkpoint.keys())
|
228 |
+
for key in keys:
|
229 |
+
if key.startswith(unet_key):
|
230 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
231 |
+
|
232 |
+
new_checkpoint = {}
|
233 |
+
|
234 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
235 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
236 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
237 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
238 |
+
|
239 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
240 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
243 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
244 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
245 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
246 |
+
|
247 |
+
# Retrieves the keys for the input blocks only
|
248 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
249 |
+
input_blocks = {
|
250 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
251 |
+
for layer_id in range(num_input_blocks)
|
252 |
+
}
|
253 |
+
|
254 |
+
# Retrieves the keys for the middle blocks only
|
255 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
256 |
+
middle_blocks = {
|
257 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
258 |
+
for layer_id in range(num_middle_blocks)
|
259 |
+
}
|
260 |
+
|
261 |
+
# Retrieves the keys for the output blocks only
|
262 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
263 |
+
output_blocks = {
|
264 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
265 |
+
for layer_id in range(num_output_blocks)
|
266 |
+
}
|
267 |
+
|
268 |
+
for i in range(1, num_input_blocks):
|
269 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
270 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
271 |
+
|
272 |
+
resnets = [
|
273 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
274 |
+
]
|
275 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
276 |
+
|
277 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
278 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
279 |
+
f"input_blocks.{i}.0.op.weight"
|
280 |
+
)
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.bias"
|
283 |
+
)
|
284 |
+
|
285 |
+
paths = renew_resnet_paths(resnets)
|
286 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
287 |
+
assign_to_checkpoint(
|
288 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
289 |
+
)
|
290 |
+
|
291 |
+
if len(attentions):
|
292 |
+
paths = renew_attention_paths(attentions)
|
293 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
294 |
+
assign_to_checkpoint(
|
295 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
296 |
+
)
|
297 |
+
|
298 |
+
resnet_0 = middle_blocks[0]
|
299 |
+
attentions = middle_blocks[1]
|
300 |
+
resnet_1 = middle_blocks[2]
|
301 |
+
|
302 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
303 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
304 |
+
|
305 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
306 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
attentions_paths = renew_attention_paths(attentions)
|
309 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
310 |
+
assign_to_checkpoint(
|
311 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
312 |
+
)
|
313 |
+
|
314 |
+
for i in range(num_output_blocks):
|
315 |
+
block_id = i // (config["layers_per_block"] + 1)
|
316 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
317 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
318 |
+
output_block_list = {}
|
319 |
+
|
320 |
+
for layer in output_block_layers:
|
321 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
322 |
+
if layer_id in output_block_list:
|
323 |
+
output_block_list[layer_id].append(layer_name)
|
324 |
+
else:
|
325 |
+
output_block_list[layer_id] = [layer_name]
|
326 |
+
|
327 |
+
if len(output_block_list) > 1:
|
328 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
329 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
330 |
+
|
331 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
332 |
+
paths = renew_resnet_paths(resnets)
|
333 |
+
|
334 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
335 |
+
assign_to_checkpoint(
|
336 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
337 |
+
)
|
338 |
+
|
339 |
+
# オリジナル:
|
340 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
341 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
342 |
+
|
343 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
344 |
+
for l in output_block_list.values():
|
345 |
+
l.sort()
|
346 |
+
|
347 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
348 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
349 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
350 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
351 |
+
]
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
354 |
+
]
|
355 |
+
|
356 |
+
# Clear attentions as they have been attributed above.
|
357 |
+
if len(attentions) == 2:
|
358 |
+
attentions = []
|
359 |
+
|
360 |
+
if len(attentions):
|
361 |
+
paths = renew_attention_paths(attentions)
|
362 |
+
meta_path = {
|
363 |
+
"old": f"output_blocks.{i}.1",
|
364 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
365 |
+
}
|
366 |
+
assign_to_checkpoint(
|
367 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
371 |
+
for path in resnet_0_paths:
|
372 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
373 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
374 |
+
|
375 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
376 |
+
|
377 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
378 |
+
if v2:
|
379 |
+
linear_transformer_to_conv(new_checkpoint)
|
380 |
+
|
381 |
+
return new_checkpoint
|
382 |
+
|
383 |
+
|
384 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
385 |
+
# extract state dict for VAE
|
386 |
+
vae_state_dict = {}
|
387 |
+
vae_key = "first_stage_model."
|
388 |
+
keys = list(checkpoint.keys())
|
389 |
+
for key in keys:
|
390 |
+
if key.startswith(vae_key):
|
391 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
392 |
+
# if len(vae_state_dict) == 0:
|
393 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
394 |
+
# vae_state_dict = checkpoint
|
395 |
+
|
396 |
+
new_checkpoint = {}
|
397 |
+
|
398 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
399 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
400 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
401 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
402 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
403 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
404 |
+
|
405 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
406 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
407 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
408 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
409 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
410 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
411 |
+
|
412 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
413 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
414 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
415 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
416 |
+
|
417 |
+
# Retrieves the keys for the encoder down blocks only
|
418 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
419 |
+
down_blocks = {
|
420 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
421 |
+
}
|
422 |
+
|
423 |
+
# Retrieves the keys for the decoder up blocks only
|
424 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
425 |
+
up_blocks = {
|
426 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
427 |
+
}
|
428 |
+
|
429 |
+
for i in range(num_down_blocks):
|
430 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
431 |
+
|
432 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
433 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
434 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
435 |
+
)
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
438 |
+
)
|
439 |
+
|
440 |
+
paths = renew_vae_resnet_paths(resnets)
|
441 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
442 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
443 |
+
|
444 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
445 |
+
num_mid_res_blocks = 2
|
446 |
+
for i in range(1, num_mid_res_blocks + 1):
|
447 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
448 |
+
|
449 |
+
paths = renew_vae_resnet_paths(resnets)
|
450 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
451 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
452 |
+
|
453 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
454 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
455 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
456 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
457 |
+
conv_attn_to_linear(new_checkpoint)
|
458 |
+
|
459 |
+
for i in range(num_up_blocks):
|
460 |
+
block_id = num_up_blocks - 1 - i
|
461 |
+
resnets = [
|
462 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
463 |
+
]
|
464 |
+
|
465 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
466 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
467 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
468 |
+
]
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
471 |
+
]
|
472 |
+
|
473 |
+
paths = renew_vae_resnet_paths(resnets)
|
474 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
475 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
476 |
+
|
477 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
478 |
+
num_mid_res_blocks = 2
|
479 |
+
for i in range(1, num_mid_res_blocks + 1):
|
480 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
481 |
+
|
482 |
+
paths = renew_vae_resnet_paths(resnets)
|
483 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
484 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
485 |
+
|
486 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
487 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
488 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
489 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
490 |
+
conv_attn_to_linear(new_checkpoint)
|
491 |
+
return new_checkpoint
|
492 |
+
|
493 |
+
|
494 |
+
def create_unet_diffusers_config(v2):
|
495 |
+
"""
|
496 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
497 |
+
"""
|
498 |
+
# unet_params = original_config.model.params.unet_config.params
|
499 |
+
|
500 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
501 |
+
|
502 |
+
down_block_types = []
|
503 |
+
resolution = 1
|
504 |
+
for i in range(len(block_out_channels)):
|
505 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
506 |
+
down_block_types.append(block_type)
|
507 |
+
if i != len(block_out_channels) - 1:
|
508 |
+
resolution *= 2
|
509 |
+
|
510 |
+
up_block_types = []
|
511 |
+
for i in range(len(block_out_channels)):
|
512 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
513 |
+
up_block_types.append(block_type)
|
514 |
+
resolution //= 2
|
515 |
+
|
516 |
+
config = dict(
|
517 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
518 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
519 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
520 |
+
down_block_types=tuple(down_block_types),
|
521 |
+
up_block_types=tuple(up_block_types),
|
522 |
+
block_out_channels=tuple(block_out_channels),
|
523 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
524 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
525 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
526 |
+
)
|
527 |
+
|
528 |
+
return config
|
529 |
+
|
530 |
+
|
531 |
+
def create_vae_diffusers_config():
|
532 |
+
"""
|
533 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
534 |
+
"""
|
535 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
536 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
537 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
538 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
539 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
540 |
+
|
541 |
+
config = dict(
|
542 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
543 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
544 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
545 |
+
down_block_types=tuple(down_block_types),
|
546 |
+
up_block_types=tuple(up_block_types),
|
547 |
+
block_out_channels=tuple(block_out_channels),
|
548 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
549 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
550 |
+
)
|
551 |
+
return config
|
552 |
+
|
553 |
+
|
554 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
555 |
+
keys = list(checkpoint.keys())
|
556 |
+
text_model_dict = {}
|
557 |
+
for key in keys:
|
558 |
+
if key.startswith("cond_stage_model.transformer"):
|
559 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
560 |
+
return text_model_dict
|
561 |
+
|
562 |
+
|
563 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
564 |
+
# 嫌になるくらい違うぞ!
|
565 |
+
def convert_key(key):
|
566 |
+
if not key.startswith("cond_stage_model"):
|
567 |
+
return None
|
568 |
+
|
569 |
+
# common conversion
|
570 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
571 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
572 |
+
|
573 |
+
if "resblocks" in key:
|
574 |
+
# resblocks conversion
|
575 |
+
key = key.replace(".resblocks.", ".layers.")
|
576 |
+
if ".ln_" in key:
|
577 |
+
key = key.replace(".ln_", ".layer_norm")
|
578 |
+
elif ".mlp." in key:
|
579 |
+
key = key.replace(".c_fc.", ".fc1.")
|
580 |
+
key = key.replace(".c_proj.", ".fc2.")
|
581 |
+
elif '.attn.out_proj' in key:
|
582 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
583 |
+
elif '.attn.in_proj' in key:
|
584 |
+
key = None # 特殊なので後で処理する
|
585 |
+
else:
|
586 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
587 |
+
elif '.positional_embedding' in key:
|
588 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
589 |
+
elif '.text_projection' in key:
|
590 |
+
key = None # 使われない???
|
591 |
+
elif '.logit_scale' in key:
|
592 |
+
key = None # 使われない???
|
593 |
+
elif '.token_embedding' in key:
|
594 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
595 |
+
elif '.ln_final' in key:
|
596 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
597 |
+
return key
|
598 |
+
|
599 |
+
keys = list(checkpoint.keys())
|
600 |
+
new_sd = {}
|
601 |
+
for key in keys:
|
602 |
+
# remove resblocks 23
|
603 |
+
if '.resblocks.23.' in key:
|
604 |
+
continue
|
605 |
+
new_key = convert_key(key)
|
606 |
+
if new_key is None:
|
607 |
+
continue
|
608 |
+
new_sd[new_key] = checkpoint[key]
|
609 |
+
|
610 |
+
# attnの変換
|
611 |
+
for key in keys:
|
612 |
+
if '.resblocks.23.' in key:
|
613 |
+
continue
|
614 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
615 |
+
# 三つに分割
|
616 |
+
values = torch.chunk(checkpoint[key], 3)
|
617 |
+
|
618 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
619 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
620 |
+
key_pfx = key_pfx.replace("_weight", "")
|
621 |
+
key_pfx = key_pfx.replace("_bias", "")
|
622 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
623 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
624 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
625 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
626 |
+
|
627 |
+
# rename or add position_ids
|
628 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
629 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
630 |
+
# waifu diffusion v1.4
|
631 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
632 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
633 |
+
else:
|
634 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
635 |
+
|
636 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
637 |
+
return new_sd
|
638 |
+
|
639 |
+
# endregion
|
640 |
+
|
641 |
+
|
642 |
+
# region Diffusers->StableDiffusion の変換コード
|
643 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
644 |
+
|
645 |
+
def conv_transformer_to_linear(checkpoint):
|
646 |
+
keys = list(checkpoint.keys())
|
647 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
648 |
+
for key in keys:
|
649 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
650 |
+
if checkpoint[key].ndim > 2:
|
651 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
652 |
+
|
653 |
+
|
654 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
655 |
+
unet_conversion_map = [
|
656 |
+
# (stable-diffusion, HF Diffusers)
|
657 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
658 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
659 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
660 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
661 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
662 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
663 |
+
("out.0.weight", "conv_norm_out.weight"),
|
664 |
+
("out.0.bias", "conv_norm_out.bias"),
|
665 |
+
("out.2.weight", "conv_out.weight"),
|
666 |
+
("out.2.bias", "conv_out.bias"),
|
667 |
+
]
|
668 |
+
|
669 |
+
unet_conversion_map_resnet = [
|
670 |
+
# (stable-diffusion, HF Diffusers)
|
671 |
+
("in_layers.0", "norm1"),
|
672 |
+
("in_layers.2", "conv1"),
|
673 |
+
("out_layers.0", "norm2"),
|
674 |
+
("out_layers.3", "conv2"),
|
675 |
+
("emb_layers.1", "time_emb_proj"),
|
676 |
+
("skip_connection", "conv_shortcut"),
|
677 |
+
]
|
678 |
+
|
679 |
+
unet_conversion_map_layer = []
|
680 |
+
for i in range(4):
|
681 |
+
# loop over downblocks/upblocks
|
682 |
+
|
683 |
+
for j in range(2):
|
684 |
+
# loop over resnets/attentions for downblocks
|
685 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
686 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
687 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
688 |
+
|
689 |
+
if i < 3:
|
690 |
+
# no attention layers in down_blocks.3
|
691 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
692 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
693 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
694 |
+
|
695 |
+
for j in range(3):
|
696 |
+
# loop over resnets/attentions for upblocks
|
697 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
698 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
699 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
700 |
+
|
701 |
+
if i > 0:
|
702 |
+
# no attention layers in up_blocks.0
|
703 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
704 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
705 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
706 |
+
|
707 |
+
if i < 3:
|
708 |
+
# no downsample in down_blocks.3
|
709 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
710 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
711 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
712 |
+
|
713 |
+
# no upsample in up_blocks.3
|
714 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
715 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
716 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
717 |
+
|
718 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
719 |
+
sd_mid_atn_prefix = "middle_block.1."
|
720 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
721 |
+
|
722 |
+
for j in range(2):
|
723 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
724 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
725 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
726 |
+
|
727 |
+
# buyer beware: this is a *brittle* function,
|
728 |
+
# and correct output requires that all of these pieces interact in
|
729 |
+
# the exact order in which I have arranged them.
|
730 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
731 |
+
for sd_name, hf_name in unet_conversion_map:
|
732 |
+
mapping[hf_name] = sd_name
|
733 |
+
for k, v in mapping.items():
|
734 |
+
if "resnets" in k:
|
735 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
736 |
+
v = v.replace(hf_part, sd_part)
|
737 |
+
mapping[k] = v
|
738 |
+
for k, v in mapping.items():
|
739 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
740 |
+
v = v.replace(hf_part, sd_part)
|
741 |
+
mapping[k] = v
|
742 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
743 |
+
|
744 |
+
if v2:
|
745 |
+
conv_transformer_to_linear(new_state_dict)
|
746 |
+
|
747 |
+
return new_state_dict
|
748 |
+
|
749 |
+
|
750 |
+
# ================#
|
751 |
+
# VAE Conversion #
|
752 |
+
# ================#
|
753 |
+
|
754 |
+
def reshape_weight_for_sd(w):
|
755 |
+
# convert HF linear weights to SD conv2d weights
|
756 |
+
return w.reshape(*w.shape, 1, 1)
|
757 |
+
|
758 |
+
|
759 |
+
def convert_vae_state_dict(vae_state_dict):
|
760 |
+
vae_conversion_map = [
|
761 |
+
# (stable-diffusion, HF Diffusers)
|
762 |
+
("nin_shortcut", "conv_shortcut"),
|
763 |
+
("norm_out", "conv_norm_out"),
|
764 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
765 |
+
]
|
766 |
+
|
767 |
+
for i in range(4):
|
768 |
+
# down_blocks have two resnets
|
769 |
+
for j in range(2):
|
770 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
771 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
772 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
773 |
+
|
774 |
+
if i < 3:
|
775 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
776 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
777 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
778 |
+
|
779 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
780 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
781 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
782 |
+
|
783 |
+
# up_blocks have three resnets
|
784 |
+
# also, up blocks in hf are numbered in reverse from sd
|
785 |
+
for j in range(3):
|
786 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
787 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
788 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
789 |
+
|
790 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
791 |
+
for i in range(2):
|
792 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
793 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
794 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
795 |
+
|
796 |
+
vae_conversion_map_attn = [
|
797 |
+
# (stable-diffusion, HF Diffusers)
|
798 |
+
("norm.", "group_norm."),
|
799 |
+
("q.", "query."),
|
800 |
+
("k.", "key."),
|
801 |
+
("v.", "value."),
|
802 |
+
("proj_out.", "proj_attn."),
|
803 |
+
]
|
804 |
+
|
805 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
806 |
+
for k, v in mapping.items():
|
807 |
+
for sd_part, hf_part in vae_conversion_map:
|
808 |
+
v = v.replace(hf_part, sd_part)
|
809 |
+
mapping[k] = v
|
810 |
+
for k, v in mapping.items():
|
811 |
+
if "attentions" in k:
|
812 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
813 |
+
v = v.replace(hf_part, sd_part)
|
814 |
+
mapping[k] = v
|
815 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
816 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
817 |
+
for k, v in new_state_dict.items():
|
818 |
+
for weight_name in weights_to_convert:
|
819 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
820 |
+
# print(f"Reshaping {k} for SD format")
|
821 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
822 |
+
|
823 |
+
return new_state_dict
|
824 |
+
|
825 |
+
|
826 |
+
# endregion
|
827 |
+
|
828 |
+
# region 自作のモデル読み書きなど
|
829 |
+
|
830 |
+
def is_safetensors(path):
|
831 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
832 |
+
|
833 |
+
|
834 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
835 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
836 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
837 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
838 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
839 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
840 |
+
]
|
841 |
+
|
842 |
+
if is_safetensors(ckpt_path):
|
843 |
+
checkpoint = None
|
844 |
+
state_dict = load_file(ckpt_path, "cpu")
|
845 |
+
else:
|
846 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
847 |
+
if "state_dict" in checkpoint:
|
848 |
+
state_dict = checkpoint["state_dict"]
|
849 |
+
else:
|
850 |
+
state_dict = checkpoint
|
851 |
+
checkpoint = None
|
852 |
+
|
853 |
+
key_reps = []
|
854 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
855 |
+
for key in state_dict.keys():
|
856 |
+
if key.startswith(rep_from):
|
857 |
+
new_key = rep_to + key[len(rep_from):]
|
858 |
+
key_reps.append((key, new_key))
|
859 |
+
|
860 |
+
for key, new_key in key_reps:
|
861 |
+
state_dict[new_key] = state_dict[key]
|
862 |
+
del state_dict[key]
|
863 |
+
|
864 |
+
return checkpoint, state_dict
|
865 |
+
|
866 |
+
|
867 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
868 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
869 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
870 |
+
if dtype is not None:
|
871 |
+
for k, v in state_dict.items():
|
872 |
+
if type(v) is torch.Tensor:
|
873 |
+
state_dict[k] = v.to(dtype)
|
874 |
+
|
875 |
+
# Convert the UNet2DConditionModel model.
|
876 |
+
unet_config = create_unet_diffusers_config(v2)
|
877 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
878 |
+
|
879 |
+
unet = UNet2DConditionModel(**unet_config)
|
880 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
881 |
+
print("loading u-net:", info)
|
882 |
+
|
883 |
+
# Convert the VAE model.
|
884 |
+
vae_config = create_vae_diffusers_config()
|
885 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
886 |
+
|
887 |
+
vae = AutoencoderKL(**vae_config)
|
888 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
889 |
+
print("loading vae:", info)
|
890 |
+
|
891 |
+
# convert text_model
|
892 |
+
if v2:
|
893 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
894 |
+
cfg = CLIPTextConfig(
|
895 |
+
vocab_size=49408,
|
896 |
+
hidden_size=1024,
|
897 |
+
intermediate_size=4096,
|
898 |
+
num_hidden_layers=23,
|
899 |
+
num_attention_heads=16,
|
900 |
+
max_position_embeddings=77,
|
901 |
+
hidden_act="gelu",
|
902 |
+
layer_norm_eps=1e-05,
|
903 |
+
dropout=0.0,
|
904 |
+
attention_dropout=0.0,
|
905 |
+
initializer_range=0.02,
|
906 |
+
initializer_factor=1.0,
|
907 |
+
pad_token_id=1,
|
908 |
+
bos_token_id=0,
|
909 |
+
eos_token_id=2,
|
910 |
+
model_type="clip_text_model",
|
911 |
+
projection_dim=512,
|
912 |
+
torch_dtype="float32",
|
913 |
+
transformers_version="4.25.0.dev0",
|
914 |
+
)
|
915 |
+
text_model = CLIPTextModel._from_config(cfg)
|
916 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
+
else:
|
918 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
920 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
+
print("loading text encoder:", info)
|
922 |
+
|
923 |
+
return text_model, vae, unet
|
924 |
+
|
925 |
+
|
926 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
927 |
+
def convert_key(key):
|
928 |
+
# position_idsの除去
|
929 |
+
if ".position_ids" in key:
|
930 |
+
return None
|
931 |
+
|
932 |
+
# common
|
933 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
934 |
+
key = key.replace("text_model.", "")
|
935 |
+
if "layers" in key:
|
936 |
+
# resblocks conversion
|
937 |
+
key = key.replace(".layers.", ".resblocks.")
|
938 |
+
if ".layer_norm" in key:
|
939 |
+
key = key.replace(".layer_norm", ".ln_")
|
940 |
+
elif ".mlp." in key:
|
941 |
+
key = key.replace(".fc1.", ".c_fc.")
|
942 |
+
key = key.replace(".fc2.", ".c_proj.")
|
943 |
+
elif '.self_attn.out_proj' in key:
|
944 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
945 |
+
elif '.self_attn.' in key:
|
946 |
+
key = None # 特殊なので後で処理する
|
947 |
+
else:
|
948 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
949 |
+
elif '.position_embedding' in key:
|
950 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
951 |
+
elif '.token_embedding' in key:
|
952 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
953 |
+
elif 'final_layer_norm' in key:
|
954 |
+
key = key.replace("final_layer_norm", "ln_final")
|
955 |
+
return key
|
956 |
+
|
957 |
+
keys = list(checkpoint.keys())
|
958 |
+
new_sd = {}
|
959 |
+
for key in keys:
|
960 |
+
new_key = convert_key(key)
|
961 |
+
if new_key is None:
|
962 |
+
continue
|
963 |
+
new_sd[new_key] = checkpoint[key]
|
964 |
+
|
965 |
+
# attnの変換
|
966 |
+
for key in keys:
|
967 |
+
if 'layers' in key and 'q_proj' in key:
|
968 |
+
# 三つを結合
|
969 |
+
key_q = key
|
970 |
+
key_k = key.replace("q_proj", "k_proj")
|
971 |
+
key_v = key.replace("q_proj", "v_proj")
|
972 |
+
|
973 |
+
value_q = checkpoint[key_q]
|
974 |
+
value_k = checkpoint[key_k]
|
975 |
+
value_v = checkpoint[key_v]
|
976 |
+
value = torch.cat([value_q, value_k, value_v])
|
977 |
+
|
978 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
979 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
980 |
+
new_sd[new_key] = value
|
981 |
+
|
982 |
+
# 最後の層などを捏造するか
|
983 |
+
if make_dummy_weights:
|
984 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
985 |
+
keys = list(new_sd.keys())
|
986 |
+
for key in keys:
|
987 |
+
if key.startswith("transformer.resblocks.22."):
|
988 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
989 |
+
|
990 |
+
# Diffusersに含まれない重みを作っておく
|
991 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
992 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
993 |
+
|
994 |
+
return new_sd
|
995 |
+
|
996 |
+
|
997 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
998 |
+
if ckpt_path is not None:
|
999 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1000 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1001 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1002 |
+
checkpoint = {}
|
1003 |
+
strict = False
|
1004 |
+
else:
|
1005 |
+
strict = True
|
1006 |
+
if "state_dict" in state_dict:
|
1007 |
+
del state_dict["state_dict"]
|
1008 |
+
else:
|
1009 |
+
# 新しく作る
|
1010 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1011 |
+
checkpoint = {}
|
1012 |
+
state_dict = {}
|
1013 |
+
strict = False
|
1014 |
+
|
1015 |
+
def update_sd(prefix, sd):
|
1016 |
+
for k, v in sd.items():
|
1017 |
+
key = prefix + k
|
1018 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1019 |
+
if save_dtype is not None:
|
1020 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1021 |
+
state_dict[key] = v
|
1022 |
+
|
1023 |
+
# Convert the UNet model
|
1024 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1025 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1026 |
+
|
1027 |
+
# Convert the text encoder model
|
1028 |
+
if v2:
|
1029 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製��て作るなどダミーの重みを入れる
|
1030 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1031 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1032 |
+
else:
|
1033 |
+
text_enc_dict = text_encoder.state_dict()
|
1034 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1035 |
+
|
1036 |
+
# Convert the VAE
|
1037 |
+
if vae is not None:
|
1038 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1039 |
+
update_sd("first_stage_model.", vae_dict)
|
1040 |
+
|
1041 |
+
# Put together new checkpoint
|
1042 |
+
key_count = len(state_dict.keys())
|
1043 |
+
new_ckpt = {'state_dict': state_dict}
|
1044 |
+
|
1045 |
+
if 'epoch' in checkpoint:
|
1046 |
+
epochs += checkpoint['epoch']
|
1047 |
+
if 'global_step' in checkpoint:
|
1048 |
+
steps += checkpoint['global_step']
|
1049 |
+
|
1050 |
+
new_ckpt['epoch'] = epochs
|
1051 |
+
new_ckpt['global_step'] = steps
|
1052 |
+
|
1053 |
+
if is_safetensors(output_file):
|
1054 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1055 |
+
save_file(state_dict, output_file)
|
1056 |
+
else:
|
1057 |
+
torch.save(new_ckpt, output_file)
|
1058 |
+
|
1059 |
+
return key_count
|
1060 |
+
|
1061 |
+
|
1062 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1063 |
+
if pretrained_model_name_or_path is None:
|
1064 |
+
# load default settings for v1/v2
|
1065 |
+
if v2:
|
1066 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1067 |
+
else:
|
1068 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1069 |
+
|
1070 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1071 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1072 |
+
if vae is None:
|
1073 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1074 |
+
|
1075 |
+
pipeline = StableDiffusionPipeline(
|
1076 |
+
unet=unet,
|
1077 |
+
text_encoder=text_encoder,
|
1078 |
+
vae=vae,
|
1079 |
+
scheduler=scheduler,
|
1080 |
+
tokenizer=tokenizer,
|
1081 |
+
safety_checker=None,
|
1082 |
+
feature_extractor=None,
|
1083 |
+
requires_safety_checker=None,
|
1084 |
+
)
|
1085 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1086 |
+
|
1087 |
+
|
1088 |
+
VAE_PREFIX = "first_stage_model."
|
1089 |
+
|
1090 |
+
|
1091 |
+
def load_vae(vae_id, dtype):
|
1092 |
+
print(f"load VAE: {vae_id}")
|
1093 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1094 |
+
# Diffusers local/remote
|
1095 |
+
try:
|
1096 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1097 |
+
except EnvironmentError as e:
|
1098 |
+
print(f"exception occurs in loading vae: {e}")
|
1099 |
+
print("retry with subfolder='vae'")
|
1100 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1101 |
+
return vae
|
1102 |
+
|
1103 |
+
# local
|
1104 |
+
vae_config = create_vae_diffusers_config()
|
1105 |
+
|
1106 |
+
if vae_id.endswith(".bin"):
|
1107 |
+
# SD 1.5 VAE on Huggingface
|
1108 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1109 |
+
else:
|
1110 |
+
# StableDiffusion
|
1111 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1112 |
+
else torch.load(vae_id, map_location="cpu"))
|
1113 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1114 |
+
|
1115 |
+
# vae only or full model
|
1116 |
+
full_model = False
|
1117 |
+
for vae_key in vae_sd:
|
1118 |
+
if vae_key.startswith(VAE_PREFIX):
|
1119 |
+
full_model = True
|
1120 |
+
break
|
1121 |
+
if not full_model:
|
1122 |
+
sd = {}
|
1123 |
+
for key, value in vae_sd.items():
|
1124 |
+
sd[VAE_PREFIX + key] = value
|
1125 |
+
vae_sd = sd
|
1126 |
+
del sd
|
1127 |
+
|
1128 |
+
# Convert the VAE model.
|
1129 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1130 |
+
|
1131 |
+
vae = AutoencoderKL(**vae_config)
|
1132 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1133 |
+
return vae
|
1134 |
+
|
1135 |
+
# endregion
|
1136 |
+
|
1137 |
+
|
1138 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1139 |
+
max_width, max_height = max_reso
|
1140 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1141 |
+
|
1142 |
+
resos = set()
|
1143 |
+
|
1144 |
+
size = int(math.sqrt(max_area)) * divisible
|
1145 |
+
resos.add((size, size))
|
1146 |
+
|
1147 |
+
size = min_size
|
1148 |
+
while size <= max_size:
|
1149 |
+
width = size
|
1150 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1151 |
+
resos.add((width, height))
|
1152 |
+
resos.add((height, width))
|
1153 |
+
|
1154 |
+
# # make additional resos
|
1155 |
+
# if width >= height and width - divisible >= min_size:
|
1156 |
+
# resos.add((width - divisible, height))
|
1157 |
+
# resos.add((height, width - divisible))
|
1158 |
+
# if height >= width and height - divisible >= min_size:
|
1159 |
+
# resos.add((width, height - divisible))
|
1160 |
+
# resos.add((height - divisible, width))
|
1161 |
+
|
1162 |
+
size += divisible
|
1163 |
+
|
1164 |
+
resos = list(resos)
|
1165 |
+
resos.sort()
|
1166 |
+
return resos
|
1167 |
+
|
1168 |
+
|
1169 |
+
if __name__ == '__main__':
|
1170 |
+
resos = make_bucket_resolutions((512, 768))
|
1171 |
+
print(len(resos))
|
1172 |
+
print(resos)
|
1173 |
+
aspect_ratios = [w / h for w, h in resos]
|
1174 |
+
print(aspect_ratios)
|
1175 |
+
|
1176 |
+
ars = set()
|
1177 |
+
for ar in aspect_ratios:
|
1178 |
+
if ar in ars:
|
1179 |
+
print("error! duplicate ar:", ar)
|
1180 |
+
ars.add(ar)
|
build/lib/library/train_util.py
ADDED
@@ -0,0 +1,1796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# common functions for training
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
+
import glob
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import hashlib
|
15 |
+
import subprocess
|
16 |
+
from io import BytesIO
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
import torch
|
20 |
+
from torchvision import transforms
|
21 |
+
from transformers import CLIPTokenizer
|
22 |
+
import diffusers
|
23 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
24 |
+
import albumentations as albu
|
25 |
+
import numpy as np
|
26 |
+
from PIL import Image
|
27 |
+
import cv2
|
28 |
+
from einops import rearrange
|
29 |
+
from torch import einsum
|
30 |
+
import safetensors.torch
|
31 |
+
|
32 |
+
import library.model_util as model_util
|
33 |
+
|
34 |
+
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
35 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
36 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
37 |
+
|
38 |
+
# checkpointファイル名
|
39 |
+
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
40 |
+
EPOCH_FILE_NAME = "{}-{:06d}"
|
41 |
+
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
|
42 |
+
LAST_STATE_NAME = "{}-state"
|
43 |
+
DEFAULT_EPOCH_NAME = "epoch"
|
44 |
+
DEFAULT_LAST_OUTPUT_NAME = "last"
|
45 |
+
|
46 |
+
# region dataset
|
47 |
+
|
48 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
49 |
+
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
50 |
+
|
51 |
+
|
52 |
+
class ImageInfo():
|
53 |
+
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
54 |
+
self.image_key: str = image_key
|
55 |
+
self.num_repeats: int = num_repeats
|
56 |
+
self.caption: str = caption
|
57 |
+
self.is_reg: bool = is_reg
|
58 |
+
self.absolute_path: str = absolute_path
|
59 |
+
self.image_size: Tuple[int, int] = None
|
60 |
+
self.resized_size: Tuple[int, int] = None
|
61 |
+
self.bucket_reso: Tuple[int, int] = None
|
62 |
+
self.latents: torch.Tensor = None
|
63 |
+
self.latents_flipped: torch.Tensor = None
|
64 |
+
self.latents_npz: str = None
|
65 |
+
self.latents_npz_flipped: str = None
|
66 |
+
|
67 |
+
|
68 |
+
class BucketManager():
|
69 |
+
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
70 |
+
self.no_upscale = no_upscale
|
71 |
+
if max_reso is None:
|
72 |
+
self.max_reso = None
|
73 |
+
self.max_area = None
|
74 |
+
else:
|
75 |
+
self.max_reso = max_reso
|
76 |
+
self.max_area = max_reso[0] * max_reso[1]
|
77 |
+
self.min_size = min_size
|
78 |
+
self.max_size = max_size
|
79 |
+
self.reso_steps = reso_steps
|
80 |
+
|
81 |
+
self.resos = []
|
82 |
+
self.reso_to_id = {}
|
83 |
+
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
84 |
+
|
85 |
+
def add_image(self, reso, image):
|
86 |
+
bucket_id = self.reso_to_id[reso]
|
87 |
+
self.buckets[bucket_id].append(image)
|
88 |
+
|
89 |
+
def shuffle(self):
|
90 |
+
for bucket in self.buckets:
|
91 |
+
random.shuffle(bucket)
|
92 |
+
|
93 |
+
def sort(self):
|
94 |
+
# 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す
|
95 |
+
sorted_resos = self.resos.copy()
|
96 |
+
sorted_resos.sort()
|
97 |
+
|
98 |
+
sorted_buckets = []
|
99 |
+
sorted_reso_to_id = {}
|
100 |
+
for i, reso in enumerate(sorted_resos):
|
101 |
+
bucket_id = self.reso_to_id[reso]
|
102 |
+
sorted_buckets.append(self.buckets[bucket_id])
|
103 |
+
sorted_reso_to_id[reso] = i
|
104 |
+
|
105 |
+
self.resos = sorted_resos
|
106 |
+
self.buckets = sorted_buckets
|
107 |
+
self.reso_to_id = sorted_reso_to_id
|
108 |
+
|
109 |
+
def make_buckets(self):
|
110 |
+
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
111 |
+
self.set_predefined_resos(resos)
|
112 |
+
|
113 |
+
def set_predefined_resos(self, resos):
|
114 |
+
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
115 |
+
self.predefined_resos = resos.copy()
|
116 |
+
self.predefined_resos_set = set(resos)
|
117 |
+
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
118 |
+
|
119 |
+
def add_if_new_reso(self, reso):
|
120 |
+
if reso not in self.reso_to_id:
|
121 |
+
bucket_id = len(self.resos)
|
122 |
+
self.reso_to_id[reso] = bucket_id
|
123 |
+
self.resos.append(reso)
|
124 |
+
self.buckets.append([])
|
125 |
+
# print(reso, bucket_id, len(self.buckets))
|
126 |
+
|
127 |
+
def round_to_steps(self, x):
|
128 |
+
x = int(x + .5)
|
129 |
+
return x - x % self.reso_steps
|
130 |
+
|
131 |
+
def select_bucket(self, image_width, image_height):
|
132 |
+
aspect_ratio = image_width / image_height
|
133 |
+
if not self.no_upscale:
|
134 |
+
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
135 |
+
reso = (image_width, image_height)
|
136 |
+
if reso in self.predefined_resos_set:
|
137 |
+
pass
|
138 |
+
else:
|
139 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
140 |
+
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
141 |
+
reso = self.predefined_resos[predefined_bucket_id]
|
142 |
+
|
143 |
+
ar_reso = reso[0] / reso[1]
|
144 |
+
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
145 |
+
scale = reso[1] / image_height
|
146 |
+
else:
|
147 |
+
scale = reso[0] / image_width
|
148 |
+
|
149 |
+
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
150 |
+
# print("use predef", image_width, image_height, reso, resized_size)
|
151 |
+
else:
|
152 |
+
if image_width * image_height > self.max_area:
|
153 |
+
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
154 |
+
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
155 |
+
resized_height = self.max_area / resized_width
|
156 |
+
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
157 |
+
|
158 |
+
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
159 |
+
# 元のbucketingと同じロジック
|
160 |
+
b_width_rounded = self.round_to_steps(resized_width)
|
161 |
+
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
162 |
+
ar_width_rounded = b_width_rounded / b_height_in_wr
|
163 |
+
|
164 |
+
b_height_rounded = self.round_to_steps(resized_height)
|
165 |
+
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
166 |
+
ar_height_rounded = b_width_in_hr / b_height_rounded
|
167 |
+
|
168 |
+
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
169 |
+
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
170 |
+
|
171 |
+
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
172 |
+
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
173 |
+
else:
|
174 |
+
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
175 |
+
# print(resized_size)
|
176 |
+
else:
|
177 |
+
resized_size = (image_width, image_height) # リサイズは不要
|
178 |
+
|
179 |
+
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
180 |
+
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
181 |
+
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
182 |
+
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
183 |
+
|
184 |
+
reso = (bucket_width, bucket_height)
|
185 |
+
|
186 |
+
self.add_if_new_reso(reso)
|
187 |
+
|
188 |
+
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
189 |
+
return reso, resized_size, ar_error
|
190 |
+
|
191 |
+
|
192 |
+
class BucketBatchIndex(NamedTuple):
|
193 |
+
bucket_index: int
|
194 |
+
bucket_batch_size: int
|
195 |
+
batch_index: int
|
196 |
+
|
197 |
+
|
198 |
+
class BaseDataset(torch.utils.data.Dataset):
|
199 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
200 |
+
super().__init__()
|
201 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
202 |
+
self.max_token_length = max_token_length
|
203 |
+
self.shuffle_caption = shuffle_caption
|
204 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
+
# width/height is used when enable_bucket==False
|
206 |
+
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
+
self.face_crop_aug_range = face_crop_aug_range
|
208 |
+
self.flip_aug = flip_aug
|
209 |
+
self.color_aug = color_aug
|
210 |
+
self.debug_dataset = debug_dataset
|
211 |
+
self.random_crop = random_crop
|
212 |
+
self.token_padding_disabled = False
|
213 |
+
self.dataset_dirs_info = {}
|
214 |
+
self.reg_dataset_dirs_info = {}
|
215 |
+
self.tag_frequency = {}
|
216 |
+
|
217 |
+
self.enable_bucket = False
|
218 |
+
self.bucket_manager: BucketManager = None # not initialized
|
219 |
+
self.min_bucket_reso = None
|
220 |
+
self.max_bucket_reso = None
|
221 |
+
self.bucket_reso_steps = None
|
222 |
+
self.bucket_no_upscale = None
|
223 |
+
self.bucket_info = None # for metadata
|
224 |
+
|
225 |
+
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
+
|
227 |
+
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
+
self.dropout_rate: float = 0
|
229 |
+
self.dropout_every_n_epochs: int = None
|
230 |
+
self.tag_dropout_rate: float = 0
|
231 |
+
|
232 |
+
# augmentation
|
233 |
+
flip_p = 0.5 if flip_aug else 0.0
|
234 |
+
if color_aug:
|
235 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
+
self.aug = albu.Compose([
|
237 |
+
albu.OneOf([
|
238 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
+
albu.RandomGamma((95, 105), p=.5),
|
240 |
+
], p=.33),
|
241 |
+
albu.HorizontalFlip(p=flip_p)
|
242 |
+
], p=1.)
|
243 |
+
elif flip_aug:
|
244 |
+
self.aug = albu.Compose([
|
245 |
+
albu.HorizontalFlip(p=flip_p)
|
246 |
+
], p=1.)
|
247 |
+
else:
|
248 |
+
self.aug = None
|
249 |
+
|
250 |
+
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
+
|
252 |
+
self.image_data: Dict[str, ImageInfo] = {}
|
253 |
+
|
254 |
+
self.replacements = {}
|
255 |
+
|
256 |
+
def set_current_epoch(self, epoch):
|
257 |
+
self.current_epoch = epoch
|
258 |
+
|
259 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
+
self.dropout_rate = dropout_rate
|
262 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
+
self.tag_dropout_rate = tag_dropout_rate
|
264 |
+
|
265 |
+
def set_tag_frequency(self, dir_name, captions):
|
266 |
+
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
+
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
+
for caption in captions:
|
269 |
+
for tag in caption.split(","):
|
270 |
+
if tag and not tag.isspace():
|
271 |
+
tag = tag.lower()
|
272 |
+
frequency = frequency_for_dir.get(tag, 0)
|
273 |
+
frequency_for_dir[tag] = frequency + 1
|
274 |
+
|
275 |
+
def disable_token_padding(self):
|
276 |
+
self.token_padding_disabled = True
|
277 |
+
|
278 |
+
def add_replacement(self, str_from, str_to):
|
279 |
+
self.replacements[str_from] = str_to
|
280 |
+
|
281 |
+
def process_caption(self, caption):
|
282 |
+
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
284 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
285 |
+
|
286 |
+
if is_drop_out:
|
287 |
+
caption = ""
|
288 |
+
else:
|
289 |
+
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
290 |
+
def dropout_tags(tokens):
|
291 |
+
if self.tag_dropout_rate <= 0:
|
292 |
+
return tokens
|
293 |
+
l = []
|
294 |
+
for token in tokens:
|
295 |
+
if random.random() >= self.tag_dropout_rate:
|
296 |
+
l.append(token)
|
297 |
+
return l
|
298 |
+
|
299 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
300 |
+
if self.shuffle_keep_tokens is None:
|
301 |
+
if self.shuffle_caption:
|
302 |
+
random.shuffle(tokens)
|
303 |
+
|
304 |
+
tokens = dropout_tags(tokens)
|
305 |
+
else:
|
306 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
+
|
310 |
+
if self.shuffle_caption:
|
311 |
+
random.shuffle(tokens)
|
312 |
+
|
313 |
+
tokens = dropout_tags(tokens)
|
314 |
+
|
315 |
+
tokens = keep_tokens + tokens
|
316 |
+
caption = ", ".join(tokens)
|
317 |
+
|
318 |
+
# textual inversion対応
|
319 |
+
for str_from, str_to in self.replacements.items():
|
320 |
+
if str_from == "":
|
321 |
+
# replace all
|
322 |
+
if type(str_to) == list:
|
323 |
+
caption = random.choice(str_to)
|
324 |
+
else:
|
325 |
+
caption = str_to
|
326 |
+
else:
|
327 |
+
caption = caption.replace(str_from, str_to)
|
328 |
+
|
329 |
+
return caption
|
330 |
+
|
331 |
+
def get_input_ids(self, caption):
|
332 |
+
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
333 |
+
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
334 |
+
|
335 |
+
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
336 |
+
input_ids = input_ids.squeeze(0)
|
337 |
+
iids_list = []
|
338 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
339 |
+
# v1
|
340 |
+
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
341 |
+
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
342 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
343 |
+
ids_chunk = (input_ids[0].unsqueeze(0),
|
344 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
345 |
+
input_ids[-1].unsqueeze(0))
|
346 |
+
ids_chunk = torch.cat(ids_chunk)
|
347 |
+
iids_list.append(ids_chunk)
|
348 |
+
else:
|
349 |
+
# v2
|
350 |
+
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
351 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
352 |
+
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
353 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
354 |
+
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
355 |
+
ids_chunk = torch.cat(ids_chunk)
|
356 |
+
|
357 |
+
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
358 |
+
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
359 |
+
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
360 |
+
ids_chunk[-1] = self.tokenizer.eos_token_id
|
361 |
+
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
362 |
+
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
363 |
+
ids_chunk[1] = self.tokenizer.eos_token_id
|
364 |
+
|
365 |
+
iids_list.append(ids_chunk)
|
366 |
+
|
367 |
+
input_ids = torch.stack(iids_list) # 3,77
|
368 |
+
return input_ids
|
369 |
+
|
370 |
+
def register_image(self, info: ImageInfo):
|
371 |
+
self.image_data[info.image_key] = info
|
372 |
+
|
373 |
+
def make_buckets(self):
|
374 |
+
'''
|
375 |
+
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
376 |
+
min_size and max_size are ignored when enable_bucket is False
|
377 |
+
'''
|
378 |
+
print("loading image sizes.")
|
379 |
+
for info in tqdm(self.image_data.values()):
|
380 |
+
if info.image_size is None:
|
381 |
+
info.image_size = self.get_image_size(info.absolute_path)
|
382 |
+
|
383 |
+
if self.enable_bucket:
|
384 |
+
print("make buckets")
|
385 |
+
else:
|
386 |
+
print("prepare dataset")
|
387 |
+
|
388 |
+
# bucketを作成し、画像をbucketに振り分ける
|
389 |
+
if self.enable_bucket:
|
390 |
+
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
391 |
+
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
392 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
393 |
+
if not self.bucket_no_upscale:
|
394 |
+
self.bucket_manager.make_buckets()
|
395 |
+
else:
|
396 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
397 |
+
|
398 |
+
img_ar_errors = []
|
399 |
+
for image_info in self.image_data.values():
|
400 |
+
image_width, image_height = image_info.image_size
|
401 |
+
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
402 |
+
|
403 |
+
# print(image_info.image_key, image_info.bucket_reso)
|
404 |
+
img_ar_errors.append(abs(ar_error))
|
405 |
+
|
406 |
+
self.bucket_manager.sort()
|
407 |
+
else:
|
408 |
+
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
409 |
+
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
410 |
+
for image_info in self.image_data.values():
|
411 |
+
image_width, image_height = image_info.image_size
|
412 |
+
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
413 |
+
|
414 |
+
for image_info in self.image_data.values():
|
415 |
+
for _ in range(image_info.num_repeats):
|
416 |
+
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
417 |
+
|
418 |
+
# bucket情報を表示、格納する
|
419 |
+
if self.enable_bucket:
|
420 |
+
self.bucket_info = {"buckets": {}}
|
421 |
+
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
422 |
+
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
423 |
+
count = len(bucket)
|
424 |
+
if count > 0:
|
425 |
+
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
426 |
+
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
427 |
+
|
428 |
+
img_ar_errors = np.array(img_ar_errors)
|
429 |
+
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
430 |
+
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
431 |
+
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
432 |
+
|
433 |
+
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
434 |
+
self.buckets_indices: List(BucketBatchIndex) = []
|
435 |
+
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
436 |
+
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
437 |
+
for batch_index in range(batch_count):
|
438 |
+
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
439 |
+
|
440 |
+
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
441 |
+
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
442 |
+
#
|
443 |
+
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
444 |
+
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
445 |
+
# # そのためバッチサイズを画像種類までに制限する
|
446 |
+
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
447 |
+
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
448 |
+
# num_of_image_types = len(set(bucket))
|
449 |
+
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
450 |
+
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
451 |
+
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
452 |
+
# for batch_index in range(batch_count):
|
453 |
+
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
454 |
+
# ↑ここまで
|
455 |
+
|
456 |
+
self.shuffle_buckets()
|
457 |
+
self._length = len(self.buckets_indices)
|
458 |
+
|
459 |
+
def shuffle_buckets(self):
|
460 |
+
random.shuffle(self.buckets_indices)
|
461 |
+
self.bucket_manager.shuffle()
|
462 |
+
|
463 |
+
def load_image(self, image_path):
|
464 |
+
image = Image.open(image_path)
|
465 |
+
if not image.mode == "RGB":
|
466 |
+
image = image.convert("RGB")
|
467 |
+
img = np.array(image, np.uint8)
|
468 |
+
return img
|
469 |
+
|
470 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
+
image_height, image_width = image.shape[0:2]
|
472 |
+
|
473 |
+
if image_width != resized_size[0] or image_height != resized_size[1]:
|
474 |
+
# リサイズする
|
475 |
+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
476 |
+
|
477 |
+
image_height, image_width = image.shape[0:2]
|
478 |
+
if image_width > reso[0]:
|
479 |
+
trim_size = image_width - reso[0]
|
480 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
481 |
+
# print("w", trim_size, p)
|
482 |
+
image = image[:, p:p + reso[0]]
|
483 |
+
if image_height > reso[1]:
|
484 |
+
trim_size = image_height - reso[1]
|
485 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
486 |
+
# print("h", trim_size, p)
|
487 |
+
image = image[p:p + reso[1]]
|
488 |
+
|
489 |
+
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
+
return image
|
491 |
+
|
492 |
+
def cache_latents(self, vae):
|
493 |
+
# TODO ここを高速化したい
|
494 |
+
print("caching latents.")
|
495 |
+
for info in tqdm(self.image_data.values()):
|
496 |
+
if info.latents_npz is not None:
|
497 |
+
info.latents = self.load_latents_from_npz(info, False)
|
498 |
+
info.latents = torch.FloatTensor(info.latents)
|
499 |
+
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
500 |
+
if info.latents_flipped is not None:
|
501 |
+
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
502 |
+
continue
|
503 |
+
|
504 |
+
image = self.load_image(info.absolute_path)
|
505 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
+
|
507 |
+
img_tensor = self.image_transforms(image)
|
508 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
+
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
+
|
511 |
+
if self.flip_aug:
|
512 |
+
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
+
img_tensor = self.image_transforms(image)
|
514 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
515 |
+
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
516 |
+
|
517 |
+
def get_image_size(self, image_path):
|
518 |
+
image = Image.open(image_path)
|
519 |
+
return image.size
|
520 |
+
|
521 |
+
def load_image_with_face_info(self, image_path: str):
|
522 |
+
img = self.load_image(image_path)
|
523 |
+
|
524 |
+
face_cx = face_cy = face_w = face_h = 0
|
525 |
+
if self.face_crop_aug_range is not None:
|
526 |
+
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
+
if len(tokens) >= 5:
|
528 |
+
face_cx = int(tokens[-4])
|
529 |
+
face_cy = int(tokens[-3])
|
530 |
+
face_w = int(tokens[-2])
|
531 |
+
face_h = int(tokens[-1])
|
532 |
+
|
533 |
+
return img, face_cx, face_cy, face_w, face_h
|
534 |
+
|
535 |
+
# いい感じに切り出す
|
536 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
+
height, width = image.shape[0:2]
|
538 |
+
if height == self.height and width == self.width:
|
539 |
+
return image
|
540 |
+
|
541 |
+
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
+
face_size = max(face_w, face_h)
|
543 |
+
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
545 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
546 |
+
if min_scale >= max_scale: # range指定がmin==max
|
547 |
+
scale = min_scale
|
548 |
+
else:
|
549 |
+
scale = random.uniform(min_scale, max_scale)
|
550 |
+
|
551 |
+
nh = int(height * scale + .5)
|
552 |
+
nw = int(width * scale + .5)
|
553 |
+
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
554 |
+
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
555 |
+
face_cx = int(face_cx * scale + .5)
|
556 |
+
face_cy = int(face_cy * scale + .5)
|
557 |
+
height, width = nh, nw
|
558 |
+
|
559 |
+
# 顔を中心として448*640とかへ切り出す
|
560 |
+
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
+
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
+
|
563 |
+
if self.random_crop:
|
564 |
+
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
+
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
+
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
+
else:
|
568 |
+
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
570 |
+
if face_size > self.size // 10 and face_size >= 40:
|
571 |
+
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
+
|
573 |
+
p1 = max(0, min(p1, length - target_size))
|
574 |
+
|
575 |
+
if axis == 0:
|
576 |
+
image = image[p1:p1 + target_size, :]
|
577 |
+
else:
|
578 |
+
image = image[:, p1:p1 + target_size]
|
579 |
+
|
580 |
+
return image
|
581 |
+
|
582 |
+
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
583 |
+
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
584 |
+
if npz_file is None:
|
585 |
+
return None
|
586 |
+
return np.load(npz_file)['arr_0']
|
587 |
+
|
588 |
+
def __len__(self):
|
589 |
+
return self._length
|
590 |
+
|
591 |
+
def __getitem__(self, index):
|
592 |
+
if index == 0:
|
593 |
+
self.shuffle_buckets()
|
594 |
+
|
595 |
+
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
+
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
+
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
598 |
+
|
599 |
+
loss_weights = []
|
600 |
+
captions = []
|
601 |
+
input_ids_list = []
|
602 |
+
latents_list = []
|
603 |
+
images = []
|
604 |
+
|
605 |
+
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
+
image_info = self.image_data[image_key]
|
607 |
+
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
+
|
609 |
+
# image/latentsを処理する
|
610 |
+
if image_info.latents is not None:
|
611 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
612 |
+
image = None
|
613 |
+
elif image_info.latents_npz is not None:
|
614 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
615 |
+
latents = torch.FloatTensor(latents)
|
616 |
+
image = None
|
617 |
+
else:
|
618 |
+
# 画像を読み込み、必要ならcropする
|
619 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
+
im_h, im_w = img.shape[0:2]
|
621 |
+
|
622 |
+
if self.enable_bucket:
|
623 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
+
else:
|
625 |
+
if face_cx > 0: # 顔位置情報あり
|
626 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
+
elif im_h > self.height or im_w > self.width:
|
628 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
629 |
+
if im_h > self.height:
|
630 |
+
p = random.randint(0, im_h - self.height)
|
631 |
+
img = img[p:p + self.height]
|
632 |
+
if im_w > self.width:
|
633 |
+
p = random.randint(0, im_w - self.width)
|
634 |
+
img = img[:, p:p + self.width]
|
635 |
+
|
636 |
+
im_h, im_w = img.shape[0:2]
|
637 |
+
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
+
|
639 |
+
# augmentation
|
640 |
+
if self.aug is not None:
|
641 |
+
img = self.aug(image=img)['image']
|
642 |
+
|
643 |
+
latents = None
|
644 |
+
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
645 |
+
|
646 |
+
images.append(image)
|
647 |
+
latents_list.append(latents)
|
648 |
+
|
649 |
+
caption = self.process_caption(image_info.caption)
|
650 |
+
captions.append(caption)
|
651 |
+
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
+
input_ids_list.append(self.get_input_ids(caption))
|
653 |
+
|
654 |
+
example = {}
|
655 |
+
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
656 |
+
|
657 |
+
if self.token_padding_disabled:
|
658 |
+
# padding=True means pad in the batch
|
659 |
+
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
660 |
+
else:
|
661 |
+
# batch processing seems to be good
|
662 |
+
example['input_ids'] = torch.stack(input_ids_list)
|
663 |
+
|
664 |
+
if images[0] is not None:
|
665 |
+
images = torch.stack(images)
|
666 |
+
images = images.to(memory_format=torch.contiguous_format).float()
|
667 |
+
else:
|
668 |
+
images = None
|
669 |
+
example['images'] = images
|
670 |
+
|
671 |
+
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
672 |
+
|
673 |
+
if self.debug_dataset:
|
674 |
+
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
675 |
+
example['captions'] = captions
|
676 |
+
return example
|
677 |
+
|
678 |
+
|
679 |
+
class DreamBoothDataset(BaseDataset):
|
680 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
681 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
682 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
+
|
684 |
+
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
+
|
686 |
+
self.batch_size = batch_size
|
687 |
+
self.size = min(self.width, self.height) # 短いほう
|
688 |
+
self.prior_loss_weight = prior_loss_weight
|
689 |
+
self.latents_cache = None
|
690 |
+
|
691 |
+
self.enable_bucket = enable_bucket
|
692 |
+
if self.enable_bucket:
|
693 |
+
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
694 |
+
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
695 |
+
self.min_bucket_reso = min_bucket_reso
|
696 |
+
self.max_bucket_reso = max_bucket_reso
|
697 |
+
self.bucket_reso_steps = bucket_reso_steps
|
698 |
+
self.bucket_no_upscale = bucket_no_upscale
|
699 |
+
else:
|
700 |
+
self.min_bucket_reso = None
|
701 |
+
self.max_bucket_reso = None
|
702 |
+
self.bucket_reso_steps = None # この情報は使われない
|
703 |
+
self.bucket_no_upscale = False
|
704 |
+
|
705 |
+
def read_caption(img_path):
|
706 |
+
# captionの候補ファイル名を作る
|
707 |
+
base_name = os.path.splitext(img_path)[0]
|
708 |
+
base_name_face_det = base_name
|
709 |
+
tokens = base_name.split("_")
|
710 |
+
if len(tokens) >= 5:
|
711 |
+
base_name_face_det = "_".join(tokens[:-4])
|
712 |
+
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
713 |
+
|
714 |
+
caption = None
|
715 |
+
for cap_path in cap_paths:
|
716 |
+
if os.path.isfile(cap_path):
|
717 |
+
with open(cap_path, "rt", encoding='utf-8') as f:
|
718 |
+
try:
|
719 |
+
lines = f.readlines()
|
720 |
+
except UnicodeDecodeError as e:
|
721 |
+
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
722 |
+
raise e
|
723 |
+
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
724 |
+
caption = lines[0].strip()
|
725 |
+
break
|
726 |
+
return caption
|
727 |
+
|
728 |
+
def load_dreambooth_dir(dir):
|
729 |
+
if not os.path.isdir(dir):
|
730 |
+
# print(f"ignore file: {dir}")
|
731 |
+
return 0, [], []
|
732 |
+
|
733 |
+
tokens = os.path.basename(dir).split('_')
|
734 |
+
try:
|
735 |
+
n_repeats = int(tokens[0])
|
736 |
+
except ValueError as e:
|
737 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
+
return 0, [], []
|
739 |
+
|
740 |
+
caption_by_folder = '_'.join(tokens[1:])
|
741 |
+
img_paths = glob_images(dir, "*")
|
742 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
+
|
744 |
+
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
+
captions = []
|
746 |
+
for img_path in img_paths:
|
747 |
+
cap_for_img = read_caption(img_path)
|
748 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
749 |
+
|
750 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
751 |
+
|
752 |
+
return n_repeats, img_paths, captions
|
753 |
+
|
754 |
+
print("prepare train images.")
|
755 |
+
train_dirs = os.listdir(train_data_dir)
|
756 |
+
num_train_images = 0
|
757 |
+
for dir in train_dirs:
|
758 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
759 |
+
num_train_images += n_repeats * len(img_paths)
|
760 |
+
|
761 |
+
for img_path, caption in zip(img_paths, captions):
|
762 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
763 |
+
self.register_image(info)
|
764 |
+
|
765 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
766 |
+
|
767 |
+
print(f"{num_train_images} train images with repeating.")
|
768 |
+
self.num_train_images = num_train_images
|
769 |
+
|
770 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
771 |
+
num_reg_images = 0
|
772 |
+
if reg_data_dir:
|
773 |
+
print("prepare reg images.")
|
774 |
+
reg_infos: List[ImageInfo] = []
|
775 |
+
|
776 |
+
reg_dirs = os.listdir(reg_data_dir)
|
777 |
+
for dir in reg_dirs:
|
778 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
779 |
+
num_reg_images += n_repeats * len(img_paths)
|
780 |
+
|
781 |
+
for img_path, caption in zip(img_paths, captions):
|
782 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
+
reg_infos.append(info)
|
784 |
+
|
785 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
+
|
787 |
+
print(f"{num_reg_images} reg images.")
|
788 |
+
if num_train_images < num_reg_images:
|
789 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
790 |
+
|
791 |
+
if num_reg_images == 0:
|
792 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
793 |
+
else:
|
794 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
795 |
+
n = 0
|
796 |
+
first_loop = True
|
797 |
+
while n < num_train_images:
|
798 |
+
for info in reg_infos:
|
799 |
+
if first_loop:
|
800 |
+
self.register_image(info)
|
801 |
+
n += info.num_repeats
|
802 |
+
else:
|
803 |
+
info.num_repeats += 1
|
804 |
+
n += 1
|
805 |
+
if n >= num_train_images:
|
806 |
+
break
|
807 |
+
first_loop = False
|
808 |
+
|
809 |
+
self.num_reg_images = num_reg_images
|
810 |
+
|
811 |
+
|
812 |
+
class FineTuningDataset(BaseDataset):
|
813 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
814 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
815 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
816 |
+
|
817 |
+
# メタデータを読み込む
|
818 |
+
if os.path.exists(json_file_name):
|
819 |
+
print(f"loading existing metadata: {json_file_name}")
|
820 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
+
metadata = json.load(f)
|
822 |
+
else:
|
823 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
+
|
825 |
+
self.metadata = metadata
|
826 |
+
self.train_data_dir = train_data_dir
|
827 |
+
self.batch_size = batch_size
|
828 |
+
|
829 |
+
tags_list = []
|
830 |
+
for image_key, img_md in metadata.items():
|
831 |
+
# path情報を作る
|
832 |
+
if os.path.exists(image_key):
|
833 |
+
abs_path = image_key
|
834 |
+
else:
|
835 |
+
# わりといい加減だがいい方法が思いつかん
|
836 |
+
abs_path = glob_images(train_data_dir, image_key)
|
837 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
+
abs_path = abs_path[0]
|
839 |
+
|
840 |
+
caption = img_md.get('caption')
|
841 |
+
tags = img_md.get('tags')
|
842 |
+
if caption is None:
|
843 |
+
caption = tags
|
844 |
+
elif tags is not None and len(tags) > 0:
|
845 |
+
caption = caption + ', ' + tags
|
846 |
+
tags_list.append(tags)
|
847 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
+
|
849 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
+
image_info.image_size = img_md.get('train_resolution')
|
851 |
+
|
852 |
+
if not self.color_aug and not self.random_crop:
|
853 |
+
# if npz exists, use them
|
854 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
+
|
856 |
+
self.register_image(image_info)
|
857 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
+
self.num_reg_images = 0
|
859 |
+
|
860 |
+
# TODO do not record tag freq when no tag
|
861 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
862 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
863 |
+
|
864 |
+
# check existence of all npz files
|
865 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
866 |
+
if use_npz_latents:
|
867 |
+
npz_any = False
|
868 |
+
npz_all = True
|
869 |
+
for image_info in self.image_data.values():
|
870 |
+
has_npz = image_info.latents_npz is not None
|
871 |
+
npz_any = npz_any or has_npz
|
872 |
+
|
873 |
+
if self.flip_aug:
|
874 |
+
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
875 |
+
npz_all = npz_all and has_npz
|
876 |
+
|
877 |
+
if npz_any and not npz_all:
|
878 |
+
break
|
879 |
+
|
880 |
+
if not npz_any:
|
881 |
+
use_npz_latents = False
|
882 |
+
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
883 |
+
elif not npz_all:
|
884 |
+
use_npz_latents = False
|
885 |
+
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
+
if self.flip_aug:
|
887 |
+
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
+
# else:
|
889 |
+
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
890 |
+
|
891 |
+
# check min/max bucket size
|
892 |
+
sizes = set()
|
893 |
+
resos = set()
|
894 |
+
for image_info in self.image_data.values():
|
895 |
+
if image_info.image_size is None:
|
896 |
+
sizes = None # not calculated
|
897 |
+
break
|
898 |
+
sizes.add(image_info.image_size[0])
|
899 |
+
sizes.add(image_info.image_size[1])
|
900 |
+
resos.add(tuple(image_info.image_size))
|
901 |
+
|
902 |
+
if sizes is None:
|
903 |
+
if use_npz_latents:
|
904 |
+
use_npz_latents = False
|
905 |
+
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
|
906 |
+
|
907 |
+
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
908 |
+
|
909 |
+
self.enable_bucket = enable_bucket
|
910 |
+
if self.enable_bucket:
|
911 |
+
self.min_bucket_reso = min_bucket_reso
|
912 |
+
self.max_bucket_reso = max_bucket_reso
|
913 |
+
self.bucket_reso_steps = bucket_reso_steps
|
914 |
+
self.bucket_no_upscale = bucket_no_upscale
|
915 |
+
else:
|
916 |
+
if not enable_bucket:
|
917 |
+
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
918 |
+
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
919 |
+
self.enable_bucket = True
|
920 |
+
|
921 |
+
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
922 |
+
|
923 |
+
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
924 |
+
self.bucket_manager = BucketManager(False, None, None, None, None)
|
925 |
+
self.bucket_manager.set_predefined_resos(resos)
|
926 |
+
|
927 |
+
# npz情報をきれいにしておく
|
928 |
+
if not use_npz_latents:
|
929 |
+
for image_info in self.image_data.values():
|
930 |
+
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
+
|
932 |
+
def image_key_to_npz_file(self, image_key):
|
933 |
+
base_name = os.path.splitext(image_key)[0]
|
934 |
+
npz_file_norm = base_name + '.npz'
|
935 |
+
|
936 |
+
if os.path.exists(npz_file_norm):
|
937 |
+
# image_key is full path
|
938 |
+
npz_file_flip = base_name + '_flip.npz'
|
939 |
+
if not os.path.exists(npz_file_flip):
|
940 |
+
npz_file_flip = None
|
941 |
+
return npz_file_norm, npz_file_flip
|
942 |
+
|
943 |
+
# image_key is relative path
|
944 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
945 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
946 |
+
|
947 |
+
if not os.path.exists(npz_file_norm):
|
948 |
+
npz_file_norm = None
|
949 |
+
npz_file_flip = None
|
950 |
+
elif not os.path.exists(npz_file_flip):
|
951 |
+
npz_file_flip = None
|
952 |
+
|
953 |
+
return npz_file_norm, npz_file_flip
|
954 |
+
|
955 |
+
|
956 |
+
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
+
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
+
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
+
|
960 |
+
train_dataset.set_current_epoch(1)
|
961 |
+
k = 0
|
962 |
+
for i, example in enumerate(train_dataset):
|
963 |
+
if example['latents'] is not None:
|
964 |
+
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
+
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
966 |
+
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
967 |
+
if show_input_ids:
|
968 |
+
print(f"input ids: {iid}")
|
969 |
+
if example['images'] is not None:
|
970 |
+
im = example['images'][j]
|
971 |
+
print(f"image size: {im.size()}")
|
972 |
+
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
973 |
+
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
974 |
+
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
975 |
+
if os.name == 'nt': # only windows
|
976 |
+
cv2.imshow("img", im)
|
977 |
+
k = cv2.waitKey()
|
978 |
+
cv2.destroyAllWindows()
|
979 |
+
if k == 27:
|
980 |
+
break
|
981 |
+
if k == 27 or (example['images'] is None and i >= 8):
|
982 |
+
break
|
983 |
+
|
984 |
+
|
985 |
+
def glob_images(directory, base="*"):
|
986 |
+
img_paths = []
|
987 |
+
for ext in IMAGE_EXTENSIONS:
|
988 |
+
if base == '*':
|
989 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
990 |
+
else:
|
991 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
992 |
+
# img_paths = list(set(img_paths)) # 重複を排除
|
993 |
+
# img_paths.sort()
|
994 |
+
return img_paths
|
995 |
+
|
996 |
+
|
997 |
+
def glob_images_pathlib(dir_path, recursive):
|
998 |
+
image_paths = []
|
999 |
+
if recursive:
|
1000 |
+
for ext in IMAGE_EXTENSIONS:
|
1001 |
+
image_paths += list(dir_path.rglob('*' + ext))
|
1002 |
+
else:
|
1003 |
+
for ext in IMAGE_EXTENSIONS:
|
1004 |
+
image_paths += list(dir_path.glob('*' + ext))
|
1005 |
+
# image_paths = list(set(image_paths)) # 重複を排除
|
1006 |
+
# image_paths.sort()
|
1007 |
+
return image_paths
|
1008 |
+
|
1009 |
+
# endregion
|
1010 |
+
|
1011 |
+
|
1012 |
+
# region モジュール入れ替え部
|
1013 |
+
"""
|
1014 |
+
高速化のためのモジュール入れ替え
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
# FlashAttentionを使うCrossAttention
|
1018 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
1019 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
1020 |
+
|
1021 |
+
# constants
|
1022 |
+
|
1023 |
+
EPSILON = 1e-6
|
1024 |
+
|
1025 |
+
# helper functions
|
1026 |
+
|
1027 |
+
|
1028 |
+
def exists(val):
|
1029 |
+
return val is not None
|
1030 |
+
|
1031 |
+
|
1032 |
+
def default(val, d):
|
1033 |
+
return val if exists(val) else d
|
1034 |
+
|
1035 |
+
|
1036 |
+
def model_hash(filename):
|
1037 |
+
"""Old model hash used by stable-diffusion-webui"""
|
1038 |
+
try:
|
1039 |
+
with open(filename, "rb") as file:
|
1040 |
+
m = hashlib.sha256()
|
1041 |
+
|
1042 |
+
file.seek(0x100000)
|
1043 |
+
m.update(file.read(0x10000))
|
1044 |
+
return m.hexdigest()[0:8]
|
1045 |
+
except FileNotFoundError:
|
1046 |
+
return 'NOFILE'
|
1047 |
+
|
1048 |
+
|
1049 |
+
def calculate_sha256(filename):
|
1050 |
+
"""New model hash used by stable-diffusion-webui"""
|
1051 |
+
hash_sha256 = hashlib.sha256()
|
1052 |
+
blksize = 1024 * 1024
|
1053 |
+
|
1054 |
+
with open(filename, "rb") as f:
|
1055 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
1056 |
+
hash_sha256.update(chunk)
|
1057 |
+
|
1058 |
+
return hash_sha256.hexdigest()
|
1059 |
+
|
1060 |
+
|
1061 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
1062 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
1063 |
+
save time on indexing the model later."""
|
1064 |
+
|
1065 |
+
# Because writing user metadata to the file can change the result of
|
1066 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
1067 |
+
# calculating the hash, as they are meant to be immutable
|
1068 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
1069 |
+
|
1070 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
1071 |
+
b = BytesIO(bytes)
|
1072 |
+
|
1073 |
+
model_hash = addnet_hash_safetensors(b)
|
1074 |
+
legacy_hash = addnet_hash_legacy(b)
|
1075 |
+
return model_hash, legacy_hash
|
1076 |
+
|
1077 |
+
|
1078 |
+
def addnet_hash_legacy(b):
|
1079 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1080 |
+
m = hashlib.sha256()
|
1081 |
+
|
1082 |
+
b.seek(0x100000)
|
1083 |
+
m.update(b.read(0x10000))
|
1084 |
+
return m.hexdigest()[0:8]
|
1085 |
+
|
1086 |
+
|
1087 |
+
def addnet_hash_safetensors(b):
|
1088 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1089 |
+
hash_sha256 = hashlib.sha256()
|
1090 |
+
blksize = 1024 * 1024
|
1091 |
+
|
1092 |
+
b.seek(0)
|
1093 |
+
header = b.read(8)
|
1094 |
+
n = int.from_bytes(header, "little")
|
1095 |
+
|
1096 |
+
offset = n + 8
|
1097 |
+
b.seek(offset)
|
1098 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
1099 |
+
hash_sha256.update(chunk)
|
1100 |
+
|
1101 |
+
return hash_sha256.hexdigest()
|
1102 |
+
|
1103 |
+
|
1104 |
+
def get_git_revision_hash() -> str:
|
1105 |
+
try:
|
1106 |
+
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
|
1107 |
+
except:
|
1108 |
+
return "(unknown)"
|
1109 |
+
|
1110 |
+
|
1111 |
+
# flash attention forwards and backwards
|
1112 |
+
|
1113 |
+
# https://arxiv.org/abs/2205.14135
|
1114 |
+
|
1115 |
+
|
1116 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
1117 |
+
@ staticmethod
|
1118 |
+
@ torch.no_grad()
|
1119 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
1120 |
+
""" Algorithm 2 in the paper """
|
1121 |
+
|
1122 |
+
device = q.device
|
1123 |
+
dtype = q.dtype
|
1124 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1125 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1126 |
+
|
1127 |
+
o = torch.zeros_like(q)
|
1128 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
1129 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
1130 |
+
|
1131 |
+
scale = (q.shape[-1] ** -0.5)
|
1132 |
+
|
1133 |
+
if not exists(mask):
|
1134 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
1135 |
+
else:
|
1136 |
+
mask = rearrange(mask, 'b n -> b 1 1 n')
|
1137 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
1138 |
+
|
1139 |
+
row_splits = zip(
|
1140 |
+
q.split(q_bucket_size, dim=-2),
|
1141 |
+
o.split(q_bucket_size, dim=-2),
|
1142 |
+
mask,
|
1143 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
1144 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
1148 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1149 |
+
|
1150 |
+
col_splits = zip(
|
1151 |
+
k.split(k_bucket_size, dim=-2),
|
1152 |
+
v.split(k_bucket_size, dim=-2),
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
1156 |
+
k_start_index = k_ind * k_bucket_size
|
1157 |
+
|
1158 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1159 |
+
|
1160 |
+
if exists(row_mask):
|
1161 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
1162 |
+
|
1163 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1164 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1165 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1166 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1167 |
+
|
1168 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
1169 |
+
attn_weights -= block_row_maxes
|
1170 |
+
exp_weights = torch.exp(attn_weights)
|
1171 |
+
|
1172 |
+
if exists(row_mask):
|
1173 |
+
exp_weights.masked_fill_(~row_mask, 0.)
|
1174 |
+
|
1175 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
1176 |
+
|
1177 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
1178 |
+
|
1179 |
+
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
|
1180 |
+
|
1181 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
1182 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
1183 |
+
|
1184 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
1185 |
+
|
1186 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
1187 |
+
|
1188 |
+
row_maxes.copy_(new_row_maxes)
|
1189 |
+
row_sums.copy_(new_row_sums)
|
1190 |
+
|
1191 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
1192 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
1193 |
+
|
1194 |
+
return o
|
1195 |
+
|
1196 |
+
@ staticmethod
|
1197 |
+
@ torch.no_grad()
|
1198 |
+
def backward(ctx, do):
|
1199 |
+
""" Algorithm 4 in the paper """
|
1200 |
+
|
1201 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
1202 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
1203 |
+
|
1204 |
+
device = q.device
|
1205 |
+
|
1206 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1207 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1208 |
+
|
1209 |
+
dq = torch.zeros_like(q)
|
1210 |
+
dk = torch.zeros_like(k)
|
1211 |
+
dv = torch.zeros_like(v)
|
1212 |
+
|
1213 |
+
row_splits = zip(
|
1214 |
+
q.split(q_bucket_size, dim=-2),
|
1215 |
+
o.split(q_bucket_size, dim=-2),
|
1216 |
+
do.split(q_bucket_size, dim=-2),
|
1217 |
+
mask,
|
1218 |
+
l.split(q_bucket_size, dim=-2),
|
1219 |
+
m.split(q_bucket_size, dim=-2),
|
1220 |
+
dq.split(q_bucket_size, dim=-2)
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
1224 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1225 |
+
|
1226 |
+
col_splits = zip(
|
1227 |
+
k.split(k_bucket_size, dim=-2),
|
1228 |
+
v.split(k_bucket_size, dim=-2),
|
1229 |
+
dk.split(k_bucket_size, dim=-2),
|
1230 |
+
dv.split(k_bucket_size, dim=-2),
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
1234 |
+
k_start_index = k_ind * k_bucket_size
|
1235 |
+
|
1236 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1237 |
+
|
1238 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1239 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1240 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1241 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1242 |
+
|
1243 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
1244 |
+
|
1245 |
+
if exists(row_mask):
|
1246 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.)
|
1247 |
+
|
1248 |
+
p = exp_attn_weights / lc
|
1249 |
+
|
1250 |
+
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
|
1251 |
+
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
|
1252 |
+
|
1253 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
1254 |
+
ds = p * scale * (dp - D)
|
1255 |
+
|
1256 |
+
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
|
1257 |
+
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
|
1258 |
+
|
1259 |
+
dqc.add_(dq_chunk)
|
1260 |
+
dkc.add_(dk_chunk)
|
1261 |
+
dvc.add_(dv_chunk)
|
1262 |
+
|
1263 |
+
return dq, dk, dv, None, None, None, None
|
1264 |
+
|
1265 |
+
|
1266 |
+
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
1267 |
+
if mem_eff_attn:
|
1268 |
+
replace_unet_cross_attn_to_memory_efficient()
|
1269 |
+
elif xformers:
|
1270 |
+
replace_unet_cross_attn_to_xformers()
|
1271 |
+
|
1272 |
+
|
1273 |
+
def replace_unet_cross_attn_to_memory_efficient():
|
1274 |
+
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
1275 |
+
flash_func = FlashAttentionFunction
|
1276 |
+
|
1277 |
+
def forward_flash_attn(self, x, context=None, mask=None):
|
1278 |
+
q_bucket_size = 512
|
1279 |
+
k_bucket_size = 1024
|
1280 |
+
|
1281 |
+
h = self.heads
|
1282 |
+
q = self.to_q(x)
|
1283 |
+
|
1284 |
+
context = context if context is not None else x
|
1285 |
+
context = context.to(x.dtype)
|
1286 |
+
|
1287 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1288 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1289 |
+
context_k = context_k.to(x.dtype)
|
1290 |
+
context_v = context_v.to(x.dtype)
|
1291 |
+
else:
|
1292 |
+
context_k = context
|
1293 |
+
context_v = context
|
1294 |
+
|
1295 |
+
k = self.to_k(context_k)
|
1296 |
+
v = self.to_v(context_v)
|
1297 |
+
del context, x
|
1298 |
+
|
1299 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
1300 |
+
|
1301 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
1302 |
+
|
1303 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
1304 |
+
|
1305 |
+
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
1306 |
+
out = self.to_out[0](out)
|
1307 |
+
out = self.to_out[1](out)
|
1308 |
+
return out
|
1309 |
+
|
1310 |
+
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
1311 |
+
|
1312 |
+
|
1313 |
+
def replace_unet_cross_attn_to_xformers():
|
1314 |
+
print("Replace CrossAttention.forward to use xformers")
|
1315 |
+
try:
|
1316 |
+
import xformers.ops
|
1317 |
+
except ImportError:
|
1318 |
+
raise ImportError("No xformers / xformersがインストールされていないようです")
|
1319 |
+
|
1320 |
+
def forward_xformers(self, x, context=None, mask=None):
|
1321 |
+
h = self.heads
|
1322 |
+
q_in = self.to_q(x)
|
1323 |
+
|
1324 |
+
context = default(context, x)
|
1325 |
+
context = context.to(x.dtype)
|
1326 |
+
|
1327 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1328 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1329 |
+
context_k = context_k.to(x.dtype)
|
1330 |
+
context_v = context_v.to(x.dtype)
|
1331 |
+
else:
|
1332 |
+
context_k = context
|
1333 |
+
context_v = context
|
1334 |
+
|
1335 |
+
k_in = self.to_k(context_k)
|
1336 |
+
v_in = self.to_v(context_v)
|
1337 |
+
|
1338 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
1339 |
+
del q_in, k_in, v_in
|
1340 |
+
|
1341 |
+
q = q.contiguous()
|
1342 |
+
k = k.contiguous()
|
1343 |
+
v = v.contiguous()
|
1344 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
1345 |
+
|
1346 |
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
1347 |
+
|
1348 |
+
# diffusers 0.7.0~
|
1349 |
+
out = self.to_out[0](out)
|
1350 |
+
out = self.to_out[1](out)
|
1351 |
+
return out
|
1352 |
+
|
1353 |
+
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
1354 |
+
# endregion
|
1355 |
+
|
1356 |
+
|
1357 |
+
# region arguments
|
1358 |
+
|
1359 |
+
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
1360 |
+
# for pretrained models
|
1361 |
+
parser.add_argument("--v2", action='store_true',
|
1362 |
+
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
1363 |
+
parser.add_argument("--v_parameterization", action='store_true',
|
1364 |
+
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
+
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1367 |
+
|
1368 |
+
|
1369 |
+
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
1370 |
+
parser.add_argument("--output_dir", type=str, default=None,
|
1371 |
+
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
1372 |
+
parser.add_argument("--output_name", type=str, default=None,
|
1373 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
1374 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
1375 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
1376 |
+
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
1377 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
1378 |
+
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
1379 |
+
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
1380 |
+
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
1381 |
+
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
1382 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
1383 |
+
parser.add_argument("--save_state", action="store_true",
|
1384 |
+
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
1385 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1386 |
+
|
1387 |
+
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
+
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
+
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
+
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
+
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
+
parser.add_argument("--xformers", action="store_true",
|
1397 |
+
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
1398 |
+
parser.add_argument("--vae", type=str, default=None,
|
1399 |
+
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
+
|
1401 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
+
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
1405 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
1406 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
1407 |
+
parser.add_argument("--persistent_data_loader_workers", action="store_true",
|
1408 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
|
1409 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1410 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
1411 |
+
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
1412 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
1413 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
|
1414 |
+
parser.add_argument("--mixed_precision", type=str, default="no",
|
1415 |
+
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
1416 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1417 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
1418 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
1419 |
+
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
+
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
+
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
+
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
+
parser.add_argument("--lowram", action="store_true",
|
1429 |
+
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
+
|
1431 |
+
if support_dreambooth:
|
1432 |
+
# DreamBooth training
|
1433 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
1434 |
+
help="loss weight for regularization images / 正則化画像のlossの重み")
|
1435 |
+
|
1436 |
+
|
1437 |
+
def verify_training_args(args: argparse.Namespace):
|
1438 |
+
if args.v_parameterization and not args.v2:
|
1439 |
+
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
1440 |
+
if args.v2 and args.clip_skip is not None:
|
1441 |
+
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
1442 |
+
|
1443 |
+
|
1444 |
+
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
1445 |
+
# dataset common
|
1446 |
+
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
1447 |
+
parser.add_argument("--shuffle_caption", action="store_true",
|
1448 |
+
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
1449 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
+
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1453 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1454 |
+
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
+
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
+
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
1457 |
+
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
|
1458 |
+
parser.add_argument("--random_crop", action="store_true",
|
1459 |
+
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
|
1460 |
+
parser.add_argument("--debug_dataset", action="store_true",
|
1461 |
+
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
1462 |
+
parser.add_argument("--resolution", type=str, default=None,
|
1463 |
+
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
|
1464 |
+
parser.add_argument("--cache_latents", action="store_true",
|
1465 |
+
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
1466 |
+
parser.add_argument("--enable_bucket", action="store_true",
|
1467 |
+
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
1468 |
+
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
1469 |
+
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
1470 |
+
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
1471 |
+
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
1472 |
+
parser.add_argument("--bucket_no_upscale", action="store_true",
|
1473 |
+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
1474 |
+
|
1475 |
+
if support_caption_dropout:
|
1476 |
+
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
+
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
+
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1481 |
+
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
+
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
+
|
1485 |
+
if support_dreambooth:
|
1486 |
+
# DreamBooth dataset
|
1487 |
+
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
1488 |
+
|
1489 |
+
if support_caption:
|
1490 |
+
# caption dataset
|
1491 |
+
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
|
1492 |
+
parser.add_argument("--dataset_repeats", type=int, default=1,
|
1493 |
+
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
|
1494 |
+
|
1495 |
+
|
1496 |
+
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
1497 |
+
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
|
1498 |
+
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
|
1499 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
1500 |
+
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
1501 |
+
|
1502 |
+
# endregion
|
1503 |
+
|
1504 |
+
# region utils
|
1505 |
+
|
1506 |
+
|
1507 |
+
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
+
# backward compatibility
|
1509 |
+
if args.caption_extention is not None:
|
1510 |
+
args.caption_extension = args.caption_extention
|
1511 |
+
args.caption_extention = None
|
1512 |
+
|
1513 |
+
if args.cache_latents:
|
1514 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
+
|
1517 |
+
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
+
if args.resolution is not None:
|
1519 |
+
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
1520 |
+
if len(args.resolution) == 1:
|
1521 |
+
args.resolution = (args.resolution[0], args.resolution[0])
|
1522 |
+
assert len(args.resolution) == 2, \
|
1523 |
+
f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
1524 |
+
|
1525 |
+
if args.face_crop_aug_range is not None:
|
1526 |
+
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
1527 |
+
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
|
1528 |
+
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
|
1529 |
+
else:
|
1530 |
+
args.face_crop_aug_range = None
|
1531 |
+
|
1532 |
+
if support_metadata:
|
1533 |
+
if args.in_json is not None and (args.color_aug or args.random_crop):
|
1534 |
+
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
|
1535 |
+
|
1536 |
+
|
1537 |
+
def load_tokenizer(args: argparse.Namespace):
|
1538 |
+
print("prepare tokenizer")
|
1539 |
+
if args.v2:
|
1540 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1541 |
+
else:
|
1542 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1543 |
+
if args.max_token_length is not None:
|
1544 |
+
print(f"update token length: {args.max_token_length}")
|
1545 |
+
return tokenizer
|
1546 |
+
|
1547 |
+
|
1548 |
+
def prepare_accelerator(args: argparse.Namespace):
|
1549 |
+
if args.logging_dir is None:
|
1550 |
+
log_with = None
|
1551 |
+
logging_dir = None
|
1552 |
+
else:
|
1553 |
+
log_with = "tensorboard"
|
1554 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
1555 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
1556 |
+
|
1557 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
|
1558 |
+
log_with=log_with, logging_dir=logging_dir)
|
1559 |
+
|
1560 |
+
# accelerateの互換性問題を解決する
|
1561 |
+
accelerator_0_15 = True
|
1562 |
+
try:
|
1563 |
+
accelerator.unwrap_model("dummy", True)
|
1564 |
+
print("Using accelerator 0.15.0 or above.")
|
1565 |
+
except TypeError:
|
1566 |
+
accelerator_0_15 = False
|
1567 |
+
|
1568 |
+
def unwrap_model(model):
|
1569 |
+
if accelerator_0_15:
|
1570 |
+
return accelerator.unwrap_model(model, True)
|
1571 |
+
return accelerator.unwrap_model(model)
|
1572 |
+
|
1573 |
+
return accelerator, unwrap_model
|
1574 |
+
|
1575 |
+
|
1576 |
+
def prepare_dtype(args: argparse.Namespace):
|
1577 |
+
weight_dtype = torch.float32
|
1578 |
+
if args.mixed_precision == "fp16":
|
1579 |
+
weight_dtype = torch.float16
|
1580 |
+
elif args.mixed_precision == "bf16":
|
1581 |
+
weight_dtype = torch.bfloat16
|
1582 |
+
|
1583 |
+
save_dtype = None
|
1584 |
+
if args.save_precision == "fp16":
|
1585 |
+
save_dtype = torch.float16
|
1586 |
+
elif args.save_precision == "bf16":
|
1587 |
+
save_dtype = torch.bfloat16
|
1588 |
+
elif args.save_precision == "float":
|
1589 |
+
save_dtype = torch.float32
|
1590 |
+
|
1591 |
+
return weight_dtype, save_dtype
|
1592 |
+
|
1593 |
+
|
1594 |
+
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
1596 |
+
if load_stable_diffusion_format:
|
1597 |
+
print("load StableDiffusion checkpoint")
|
1598 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1599 |
+
else:
|
1600 |
+
print("load Diffusers pretrained models")
|
1601 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
1602 |
+
text_encoder = pipe.text_encoder
|
1603 |
+
vae = pipe.vae
|
1604 |
+
unet = pipe.unet
|
1605 |
+
del pipe
|
1606 |
+
|
1607 |
+
# VAEを読み込む
|
1608 |
+
if args.vae is not None:
|
1609 |
+
vae = model_util.load_vae(args.vae, weight_dtype)
|
1610 |
+
print("additional VAE loaded")
|
1611 |
+
|
1612 |
+
return text_encoder, vae, unet, load_stable_diffusion_format
|
1613 |
+
|
1614 |
+
|
1615 |
+
def patch_accelerator_for_fp16_training(accelerator):
|
1616 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
1617 |
+
|
1618 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
1619 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
1620 |
+
|
1621 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
1622 |
+
|
1623 |
+
|
1624 |
+
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
|
1625 |
+
# with no_token_padding, the length is not max length, return result immediately
|
1626 |
+
if input_ids.size()[-1] != tokenizer.model_max_length:
|
1627 |
+
return text_encoder(input_ids)[0]
|
1628 |
+
|
1629 |
+
b_size = input_ids.size()[0]
|
1630 |
+
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
1631 |
+
|
1632 |
+
if args.clip_skip is None:
|
1633 |
+
encoder_hidden_states = text_encoder(input_ids)[0]
|
1634 |
+
else:
|
1635 |
+
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
1636 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
1637 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
1638 |
+
|
1639 |
+
# bs*3, 77, 768 or 1024
|
1640 |
+
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
1641 |
+
|
1642 |
+
if args.max_token_length is not None:
|
1643 |
+
if args.v2:
|
1644 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
1645 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1646 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1647 |
+
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
1648 |
+
if i > 0:
|
1649 |
+
for j in range(len(chunk)):
|
1650 |
+
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
1651 |
+
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
1652 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
1653 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
1654 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1655 |
+
else:
|
1656 |
+
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
1657 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1658 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1659 |
+
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
1660 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
1661 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1662 |
+
|
1663 |
+
if weight_dtype is not None:
|
1664 |
+
# this is required for additional network training
|
1665 |
+
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
1666 |
+
|
1667 |
+
return encoder_hidden_states
|
1668 |
+
|
1669 |
+
|
1670 |
+
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
1671 |
+
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
1672 |
+
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
1673 |
+
return model_name, ckpt_name
|
1674 |
+
|
1675 |
+
|
1676 |
+
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
1677 |
+
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
1678 |
+
if saving:
|
1679 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1680 |
+
save_func()
|
1681 |
+
|
1682 |
+
if args.save_last_n_epochs is not None:
|
1683 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
1684 |
+
remove_old_func(remove_epoch_no)
|
1685 |
+
return saving
|
1686 |
+
|
1687 |
+
|
1688 |
+
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
1689 |
+
epoch_no = epoch + 1
|
1690 |
+
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
1691 |
+
|
1692 |
+
if save_stable_diffusion_format:
|
1693 |
+
def save_sd():
|
1694 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1695 |
+
print(f"saving checkpoint: {ckpt_file}")
|
1696 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1697 |
+
src_path, epoch_no, global_step, save_dtype, vae)
|
1698 |
+
|
1699 |
+
def remove_sd(old_epoch_no):
|
1700 |
+
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
1701 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1702 |
+
if os.path.exists(old_ckpt_file):
|
1703 |
+
print(f"removing old checkpoint: {old_ckpt_file}")
|
1704 |
+
os.remove(old_ckpt_file)
|
1705 |
+
|
1706 |
+
save_func = save_sd
|
1707 |
+
remove_old_func = remove_sd
|
1708 |
+
else:
|
1709 |
+
def save_du():
|
1710 |
+
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
1711 |
+
print(f"saving model: {out_dir}")
|
1712 |
+
os.makedirs(out_dir, exist_ok=True)
|
1713 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1714 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1715 |
+
|
1716 |
+
def remove_du(old_epoch_no):
|
1717 |
+
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
1718 |
+
if os.path.exists(out_dir_old):
|
1719 |
+
print(f"removing old model: {out_dir_old}")
|
1720 |
+
shutil.rmtree(out_dir_old)
|
1721 |
+
|
1722 |
+
save_func = save_du
|
1723 |
+
remove_old_func = remove_du
|
1724 |
+
|
1725 |
+
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
1726 |
+
if saving and args.save_state:
|
1727 |
+
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
1728 |
+
|
1729 |
+
|
1730 |
+
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
1731 |
+
print("saving state.")
|
1732 |
+
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
1733 |
+
|
1734 |
+
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
1735 |
+
if last_n_epochs is not None:
|
1736 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
1737 |
+
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
1738 |
+
if os.path.exists(state_dir_old):
|
1739 |
+
print(f"removing old state: {state_dir_old}")
|
1740 |
+
shutil.rmtree(state_dir_old)
|
1741 |
+
|
1742 |
+
|
1743 |
+
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
|
1744 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1745 |
+
|
1746 |
+
if save_stable_diffusion_format:
|
1747 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1748 |
+
|
1749 |
+
ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
|
1750 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1751 |
+
|
1752 |
+
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
1753 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1754 |
+
src_path, epoch, global_step, save_dtype, vae)
|
1755 |
+
else:
|
1756 |
+
out_dir = os.path.join(args.output_dir, model_name)
|
1757 |
+
os.makedirs(out_dir, exist_ok=True)
|
1758 |
+
|
1759 |
+
print(f"save trained model as Diffusers to {out_dir}")
|
1760 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1761 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1762 |
+
|
1763 |
+
|
1764 |
+
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
1765 |
+
print("saving last state.")
|
1766 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1767 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
+
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
+
|
1770 |
+
# endregion
|
1771 |
+
|
1772 |
+
# region 前処理用
|
1773 |
+
|
1774 |
+
|
1775 |
+
class ImageLoadingDataset(torch.utils.data.Dataset):
|
1776 |
+
def __init__(self, image_paths):
|
1777 |
+
self.images = image_paths
|
1778 |
+
|
1779 |
+
def __len__(self):
|
1780 |
+
return len(self.images)
|
1781 |
+
|
1782 |
+
def __getitem__(self, idx):
|
1783 |
+
img_path = self.images[idx]
|
1784 |
+
|
1785 |
+
try:
|
1786 |
+
image = Image.open(img_path).convert("RGB")
|
1787 |
+
# convert to tensor temporarily so dataloader will accept it
|
1788 |
+
tensor_pil = transforms.functional.pil_to_tensor(image)
|
1789 |
+
except Exception as e:
|
1790 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
1791 |
+
return None
|
1792 |
+
|
1793 |
+
return (tensor_pil, img_path)
|
1794 |
+
|
1795 |
+
|
1796 |
+
# endregion
|
fine_tune.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# training with captions
|
2 |
+
# XXX dropped option: hypernetwork training
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import gc
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
import torch
|
11 |
+
from accelerate.utils import set_seed
|
12 |
+
import diffusers
|
13 |
+
from diffusers import DDPMScheduler
|
14 |
+
|
15 |
+
import library.train_util as train_util
|
16 |
+
|
17 |
+
|
18 |
+
def collate_fn(examples):
|
19 |
+
return examples[0]
|
20 |
+
|
21 |
+
|
22 |
+
def train(args):
|
23 |
+
train_util.verify_training_args(args)
|
24 |
+
train_util.prepare_dataset_args(args, True)
|
25 |
+
|
26 |
+
cache_latents = args.cache_latents
|
27 |
+
|
28 |
+
if args.seed is not None:
|
29 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
30 |
+
|
31 |
+
tokenizer = train_util.load_tokenizer(args)
|
32 |
+
|
33 |
+
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
34 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
35 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
36 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
37 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
38 |
+
args.dataset_repeats, args.debug_dataset)
|
39 |
+
|
40 |
+
# 学習データのdropout率を設定する
|
41 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
42 |
+
|
43 |
+
train_dataset.make_buckets()
|
44 |
+
|
45 |
+
if args.debug_dataset:
|
46 |
+
train_util.debug_dataset(train_dataset)
|
47 |
+
return
|
48 |
+
if len(train_dataset) == 0:
|
49 |
+
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
50 |
+
return
|
51 |
+
|
52 |
+
# acceleratorを準備する
|
53 |
+
print("prepare accelerator")
|
54 |
+
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
55 |
+
|
56 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
57 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
58 |
+
|
59 |
+
# モデルを読み込む
|
60 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
61 |
+
|
62 |
+
# verify load/save model formats
|
63 |
+
if load_stable_diffusion_format:
|
64 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
65 |
+
src_diffusers_model_path = None
|
66 |
+
else:
|
67 |
+
src_stable_diffusion_ckpt = None
|
68 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
69 |
+
|
70 |
+
if args.save_model_as is None:
|
71 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
72 |
+
use_safetensors = args.use_safetensors
|
73 |
+
else:
|
74 |
+
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
75 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
76 |
+
|
77 |
+
# Diffusers版のxformers使用フラグを設定する関数
|
78 |
+
def set_diffusers_xformers_flag(model, valid):
|
79 |
+
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
80 |
+
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
81 |
+
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
82 |
+
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
83 |
+
|
84 |
+
# Recursively walk through all the children.
|
85 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
86 |
+
# gets the message
|
87 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
88 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
89 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
90 |
+
|
91 |
+
for child in module.children():
|
92 |
+
fn_recursive_set_mem_eff(child)
|
93 |
+
|
94 |
+
fn_recursive_set_mem_eff(model)
|
95 |
+
|
96 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
97 |
+
if args.diffusers_xformers:
|
98 |
+
print("Use xformers by Diffusers")
|
99 |
+
set_diffusers_xformers_flag(unet, True)
|
100 |
+
else:
|
101 |
+
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
102 |
+
print("Disable Diffusers' xformers")
|
103 |
+
set_diffusers_xformers_flag(unet, False)
|
104 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
105 |
+
|
106 |
+
# 学習を準備する
|
107 |
+
if cache_latents:
|
108 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
109 |
+
vae.requires_grad_(False)
|
110 |
+
vae.eval()
|
111 |
+
with torch.no_grad():
|
112 |
+
train_dataset.cache_latents(vae)
|
113 |
+
vae.to("cpu")
|
114 |
+
if torch.cuda.is_available():
|
115 |
+
torch.cuda.empty_cache()
|
116 |
+
gc.collect()
|
117 |
+
|
118 |
+
# 学習を準備する:モデルを適切な状態にする
|
119 |
+
training_models = []
|
120 |
+
if args.gradient_checkpointing:
|
121 |
+
unet.enable_gradient_checkpointing()
|
122 |
+
training_models.append(unet)
|
123 |
+
|
124 |
+
if args.train_text_encoder:
|
125 |
+
print("enable text encoder training")
|
126 |
+
if args.gradient_checkpointing:
|
127 |
+
text_encoder.gradient_checkpointing_enable()
|
128 |
+
training_models.append(text_encoder)
|
129 |
+
else:
|
130 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
131 |
+
text_encoder.requires_grad_(False) # text encoderは学習しない
|
132 |
+
if args.gradient_checkpointing:
|
133 |
+
text_encoder.gradient_checkpointing_enable()
|
134 |
+
text_encoder.train() # required for gradient_checkpointing
|
135 |
+
else:
|
136 |
+
text_encoder.eval()
|
137 |
+
|
138 |
+
if not cache_latents:
|
139 |
+
vae.requires_grad_(False)
|
140 |
+
vae.eval()
|
141 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
142 |
+
|
143 |
+
for m in training_models:
|
144 |
+
m.requires_grad_(True)
|
145 |
+
params = []
|
146 |
+
for m in training_models:
|
147 |
+
params.extend(m.parameters())
|
148 |
+
params_to_optimize = params
|
149 |
+
|
150 |
+
# 学習に必要なクラスを準備する
|
151 |
+
print("prepare optimizer, data loader etc.")
|
152 |
+
|
153 |
+
# 8-bit Adamを使う
|
154 |
+
if args.use_8bit_adam:
|
155 |
+
try:
|
156 |
+
import bitsandbytes as bnb
|
157 |
+
except ImportError:
|
158 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
159 |
+
print("use 8-bit Adam optimizer")
|
160 |
+
optimizer_class = bnb.optim.AdamW8bit
|
161 |
+
elif args.use_lion_optimizer:
|
162 |
+
try:
|
163 |
+
import lion_pytorch
|
164 |
+
except ImportError:
|
165 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
166 |
+
print("use Lion optimizer")
|
167 |
+
optimizer_class = lion_pytorch.Lion
|
168 |
+
else:
|
169 |
+
optimizer_class = torch.optim.AdamW
|
170 |
+
|
171 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
172 |
+
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
173 |
+
|
174 |
+
# dataloaderを準備する
|
175 |
+
# DataLoaderのプロセス数:0はメインプロセスになる
|
176 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
177 |
+
train_dataloader = torch.utils.data.DataLoader(
|
178 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
179 |
+
|
180 |
+
# 学習ステップ数を計算する
|
181 |
+
if args.max_train_epochs is not None:
|
182 |
+
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
183 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
+
|
185 |
+
# lr schedulerを用意する
|
186 |
+
lr_scheduler = diffusers.optimization.get_scheduler(
|
187 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
188 |
+
|
189 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
190 |
+
if args.full_fp16:
|
191 |
+
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
192 |
+
print("enable full fp16 training.")
|
193 |
+
unet.to(weight_dtype)
|
194 |
+
text_encoder.to(weight_dtype)
|
195 |
+
|
196 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
197 |
+
if args.train_text_encoder:
|
198 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
199 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
200 |
+
else:
|
201 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
202 |
+
|
203 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
204 |
+
if args.full_fp16:
|
205 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
206 |
+
|
207 |
+
# resumeする
|
208 |
+
if args.resume is not None:
|
209 |
+
print(f"resume training from state: {args.resume}")
|
210 |
+
accelerator.load_state(args.resume)
|
211 |
+
|
212 |
+
# epoch数を計算する
|
213 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
214 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
215 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
216 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
217 |
+
|
218 |
+
# 学習する
|
219 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
220 |
+
print("running training / 学習開始")
|
221 |
+
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
|
222 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
223 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
224 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
225 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
226 |
+
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
227 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
228 |
+
|
229 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
230 |
+
global_step = 0
|
231 |
+
|
232 |
+
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
233 |
+
num_train_timesteps=1000, clip_sample=False)
|
234 |
+
|
235 |
+
if accelerator.is_main_process:
|
236 |
+
accelerator.init_trackers("finetuning")
|
237 |
+
|
238 |
+
for epoch in range(num_train_epochs):
|
239 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
240 |
+
train_dataset.set_current_epoch(epoch + 1)
|
241 |
+
|
242 |
+
for m in training_models:
|
243 |
+
m.train()
|
244 |
+
|
245 |
+
loss_total = 0
|
246 |
+
for step, batch in enumerate(train_dataloader):
|
247 |
+
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
248 |
+
with torch.no_grad():
|
249 |
+
if "latents" in batch and batch["latents"] is not None:
|
250 |
+
latents = batch["latents"].to(accelerator.device)
|
251 |
+
else:
|
252 |
+
# latentに変換
|
253 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
254 |
+
latents = latents * 0.18215
|
255 |
+
b_size = latents.shape[0]
|
256 |
+
|
257 |
+
with torch.set_grad_enabled(args.train_text_encoder):
|
258 |
+
# Get the text embedding for conditioning
|
259 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
260 |
+
encoder_hidden_states = train_util.get_hidden_states(
|
261 |
+
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
|
262 |
+
|
263 |
+
# Sample noise that we'll add to the latents
|
264 |
+
noise = torch.randn_like(latents, device=latents.device)
|
265 |
+
if args.noise_offset:
|
266 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
267 |
+
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
268 |
+
|
269 |
+
# Sample a random timestep for each image
|
270 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
271 |
+
timesteps = timesteps.long()
|
272 |
+
|
273 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
274 |
+
# (this is the forward diffusion process)
|
275 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
276 |
+
|
277 |
+
# Predict the noise residual
|
278 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
279 |
+
|
280 |
+
if args.v_parameterization:
|
281 |
+
# v-parameterization training
|
282 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
283 |
+
else:
|
284 |
+
target = noise
|
285 |
+
|
286 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
287 |
+
|
288 |
+
accelerator.backward(loss)
|
289 |
+
if accelerator.sync_gradients:
|
290 |
+
params_to_clip = []
|
291 |
+
for m in training_models:
|
292 |
+
params_to_clip.extend(m.parameters())
|
293 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
294 |
+
|
295 |
+
optimizer.step()
|
296 |
+
lr_scheduler.step()
|
297 |
+
optimizer.zero_grad(set_to_none=True)
|
298 |
+
|
299 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
300 |
+
if accelerator.sync_gradients:
|
301 |
+
progress_bar.update(1)
|
302 |
+
global_step += 1
|
303 |
+
|
304 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
305 |
+
if args.logging_dir is not None:
|
306 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
307 |
+
accelerator.log(logs, step=global_step)
|
308 |
+
|
309 |
+
loss_total += current_loss
|
310 |
+
avr_loss = loss_total / (step+1)
|
311 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
312 |
+
progress_bar.set_postfix(**logs)
|
313 |
+
|
314 |
+
if global_step >= args.max_train_steps:
|
315 |
+
break
|
316 |
+
|
317 |
+
if args.logging_dir is not None:
|
318 |
+
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
319 |
+
accelerator.log(logs, step=epoch+1)
|
320 |
+
|
321 |
+
accelerator.wait_for_everyone()
|
322 |
+
|
323 |
+
if args.save_every_n_epochs is not None:
|
324 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
325 |
+
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
326 |
+
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
327 |
+
|
328 |
+
is_main_process = accelerator.is_main_process
|
329 |
+
if is_main_process:
|
330 |
+
unet = unwrap_model(unet)
|
331 |
+
text_encoder = unwrap_model(text_encoder)
|
332 |
+
|
333 |
+
accelerator.end_training()
|
334 |
+
|
335 |
+
if args.save_state:
|
336 |
+
train_util.save_state_on_train_end(args, accelerator)
|
337 |
+
|
338 |
+
del accelerator # この後メモリを使うのでこれは消す
|
339 |
+
|
340 |
+
if is_main_process:
|
341 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
342 |
+
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
|
343 |
+
save_dtype, epoch, global_step, text_encoder, unet, vae)
|
344 |
+
print("model saved.")
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == '__main__':
|
348 |
+
parser = argparse.ArgumentParser()
|
349 |
+
|
350 |
+
train_util.add_sd_models_arguments(parser)
|
351 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
352 |
+
train_util.add_training_arguments(parser, False)
|
353 |
+
train_util.add_sd_saving_arguments(parser)
|
354 |
+
|
355 |
+
parser.add_argument("--diffusers_xformers", action='store_true',
|
356 |
+
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
357 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
358 |
+
|
359 |
+
args = parser.parse_args()
|
360 |
+
train(args)
|
gen_img_diffusers.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
library.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: library
|
3 |
+
Version: 0.0.0
|
4 |
+
License-File: LICENSE.md
|
library.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LICENSE.md
|
2 |
+
README.md
|
3 |
+
setup.py
|
4 |
+
library/__init__.py
|
5 |
+
library/model_util.py
|
6 |
+
library/train_util.py
|
7 |
+
library.egg-info/PKG-INFO
|
8 |
+
library.egg-info/SOURCES.txt
|
9 |
+
library.egg-info/dependency_links.txt
|
10 |
+
library.egg-info/top_level.txt
|
library.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
library.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
library
|
library/__init__.py
ADDED
File without changes
|
library/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (131 Bytes). View file
|
|
library/__pycache__/model_util.cpython-310.pyc
ADDED
Binary file (29.2 kB). View file
|
|
library/__pycache__/train_util.cpython-310.pyc
ADDED
Binary file (57.4 kB). View file
|
|
library/model_util.py
ADDED
@@ -0,0 +1,1180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
+
from safetensors.torch import load_file, save_file
|
10 |
+
|
11 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
12 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
13 |
+
BETA_START = 0.00085
|
14 |
+
BETA_END = 0.0120
|
15 |
+
|
16 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
17 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
18 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
19 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
20 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
21 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
22 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
23 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
24 |
+
UNET_PARAMS_NUM_HEADS = 8
|
25 |
+
|
26 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
27 |
+
VAE_PARAMS_RESOLUTION = 256
|
28 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
29 |
+
VAE_PARAMS_OUT_CH = 3
|
30 |
+
VAE_PARAMS_CH = 128
|
31 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
32 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
|
34 |
+
# V2
|
35 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
36 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
37 |
+
|
38 |
+
# Diffusersの設定を読み込むための参照モデル
|
39 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
40 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
41 |
+
|
42 |
+
|
43 |
+
# region StableDiffusion->Diffusersの変換コード
|
44 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
45 |
+
|
46 |
+
|
47 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
48 |
+
"""
|
49 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
50 |
+
"""
|
51 |
+
if n_shave_prefix_segments >= 0:
|
52 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
53 |
+
else:
|
54 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
55 |
+
|
56 |
+
|
57 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
58 |
+
"""
|
59 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
60 |
+
"""
|
61 |
+
mapping = []
|
62 |
+
for old_item in old_list:
|
63 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
64 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
65 |
+
|
66 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
67 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
68 |
+
|
69 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
70 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
71 |
+
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
+
|
74 |
+
mapping.append({"old": old_item, "new": new_item})
|
75 |
+
|
76 |
+
return mapping
|
77 |
+
|
78 |
+
|
79 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
+
|
87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
89 |
+
|
90 |
+
mapping.append({"old": old_item, "new": new_item})
|
91 |
+
|
92 |
+
return mapping
|
93 |
+
|
94 |
+
|
95 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
96 |
+
"""
|
97 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
98 |
+
"""
|
99 |
+
mapping = []
|
100 |
+
for old_item in old_list:
|
101 |
+
new_item = old_item
|
102 |
+
|
103 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
104 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
105 |
+
|
106 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
107 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
108 |
+
|
109 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
110 |
+
|
111 |
+
mapping.append({"old": old_item, "new": new_item})
|
112 |
+
|
113 |
+
return mapping
|
114 |
+
|
115 |
+
|
116 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
117 |
+
"""
|
118 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
119 |
+
"""
|
120 |
+
mapping = []
|
121 |
+
for old_item in old_list:
|
122 |
+
new_item = old_item
|
123 |
+
|
124 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
125 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
126 |
+
|
127 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
128 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
131 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
134 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
137 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
138 |
+
|
139 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
140 |
+
|
141 |
+
mapping.append({"old": old_item, "new": new_item})
|
142 |
+
|
143 |
+
return mapping
|
144 |
+
|
145 |
+
|
146 |
+
def assign_to_checkpoint(
|
147 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
151 |
+
to them. It splits attention layers, and takes into account additional replacements
|
152 |
+
that may arise.
|
153 |
+
|
154 |
+
Assigns the weights to the new checkpoint.
|
155 |
+
"""
|
156 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
157 |
+
|
158 |
+
# Splits the attention layers into three variables.
|
159 |
+
if attention_paths_to_split is not None:
|
160 |
+
for path, path_map in attention_paths_to_split.items():
|
161 |
+
old_tensor = old_checkpoint[path]
|
162 |
+
channels = old_tensor.shape[0] // 3
|
163 |
+
|
164 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
165 |
+
|
166 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
167 |
+
|
168 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
169 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
170 |
+
|
171 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
172 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
173 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
174 |
+
|
175 |
+
for path in paths:
|
176 |
+
new_path = path["new"]
|
177 |
+
|
178 |
+
# These have already been assigned
|
179 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
180 |
+
continue
|
181 |
+
|
182 |
+
# Global renaming happens here
|
183 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
184 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
185 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
186 |
+
|
187 |
+
if additional_replacements is not None:
|
188 |
+
for replacement in additional_replacements:
|
189 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
190 |
+
|
191 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
192 |
+
if "proj_attn.weight" in new_path:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
194 |
+
else:
|
195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
196 |
+
|
197 |
+
|
198 |
+
def conv_attn_to_linear(checkpoint):
|
199 |
+
keys = list(checkpoint.keys())
|
200 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
201 |
+
for key in keys:
|
202 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
205 |
+
elif "proj_attn.weight" in key:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
208 |
+
|
209 |
+
|
210 |
+
def linear_transformer_to_conv(checkpoint):
|
211 |
+
keys = list(checkpoint.keys())
|
212 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
213 |
+
for key in keys:
|
214 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
215 |
+
if checkpoint[key].ndim == 2:
|
216 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
217 |
+
|
218 |
+
|
219 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
220 |
+
"""
|
221 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
222 |
+
"""
|
223 |
+
|
224 |
+
# extract state_dict for UNet
|
225 |
+
unet_state_dict = {}
|
226 |
+
unet_key = "model.diffusion_model."
|
227 |
+
keys = list(checkpoint.keys())
|
228 |
+
for key in keys:
|
229 |
+
if key.startswith(unet_key):
|
230 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
231 |
+
|
232 |
+
new_checkpoint = {}
|
233 |
+
|
234 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
235 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
236 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
237 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
238 |
+
|
239 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
240 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
243 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
244 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
245 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
246 |
+
|
247 |
+
# Retrieves the keys for the input blocks only
|
248 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
249 |
+
input_blocks = {
|
250 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
251 |
+
for layer_id in range(num_input_blocks)
|
252 |
+
}
|
253 |
+
|
254 |
+
# Retrieves the keys for the middle blocks only
|
255 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
256 |
+
middle_blocks = {
|
257 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
258 |
+
for layer_id in range(num_middle_blocks)
|
259 |
+
}
|
260 |
+
|
261 |
+
# Retrieves the keys for the output blocks only
|
262 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
263 |
+
output_blocks = {
|
264 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
265 |
+
for layer_id in range(num_output_blocks)
|
266 |
+
}
|
267 |
+
|
268 |
+
for i in range(1, num_input_blocks):
|
269 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
270 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
271 |
+
|
272 |
+
resnets = [
|
273 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
274 |
+
]
|
275 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
276 |
+
|
277 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
278 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
279 |
+
f"input_blocks.{i}.0.op.weight"
|
280 |
+
)
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.bias"
|
283 |
+
)
|
284 |
+
|
285 |
+
paths = renew_resnet_paths(resnets)
|
286 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
287 |
+
assign_to_checkpoint(
|
288 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
289 |
+
)
|
290 |
+
|
291 |
+
if len(attentions):
|
292 |
+
paths = renew_attention_paths(attentions)
|
293 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
294 |
+
assign_to_checkpoint(
|
295 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
296 |
+
)
|
297 |
+
|
298 |
+
resnet_0 = middle_blocks[0]
|
299 |
+
attentions = middle_blocks[1]
|
300 |
+
resnet_1 = middle_blocks[2]
|
301 |
+
|
302 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
303 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
304 |
+
|
305 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
306 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
attentions_paths = renew_attention_paths(attentions)
|
309 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
310 |
+
assign_to_checkpoint(
|
311 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
312 |
+
)
|
313 |
+
|
314 |
+
for i in range(num_output_blocks):
|
315 |
+
block_id = i // (config["layers_per_block"] + 1)
|
316 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
317 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
318 |
+
output_block_list = {}
|
319 |
+
|
320 |
+
for layer in output_block_layers:
|
321 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
322 |
+
if layer_id in output_block_list:
|
323 |
+
output_block_list[layer_id].append(layer_name)
|
324 |
+
else:
|
325 |
+
output_block_list[layer_id] = [layer_name]
|
326 |
+
|
327 |
+
if len(output_block_list) > 1:
|
328 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
329 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
330 |
+
|
331 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
332 |
+
paths = renew_resnet_paths(resnets)
|
333 |
+
|
334 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
335 |
+
assign_to_checkpoint(
|
336 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
337 |
+
)
|
338 |
+
|
339 |
+
# オリジナル:
|
340 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
341 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
342 |
+
|
343 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
344 |
+
for l in output_block_list.values():
|
345 |
+
l.sort()
|
346 |
+
|
347 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
348 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
349 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
350 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
351 |
+
]
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
354 |
+
]
|
355 |
+
|
356 |
+
# Clear attentions as they have been attributed above.
|
357 |
+
if len(attentions) == 2:
|
358 |
+
attentions = []
|
359 |
+
|
360 |
+
if len(attentions):
|
361 |
+
paths = renew_attention_paths(attentions)
|
362 |
+
meta_path = {
|
363 |
+
"old": f"output_blocks.{i}.1",
|
364 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
365 |
+
}
|
366 |
+
assign_to_checkpoint(
|
367 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
371 |
+
for path in resnet_0_paths:
|
372 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
373 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
374 |
+
|
375 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
376 |
+
|
377 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
378 |
+
if v2:
|
379 |
+
linear_transformer_to_conv(new_checkpoint)
|
380 |
+
|
381 |
+
return new_checkpoint
|
382 |
+
|
383 |
+
|
384 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
385 |
+
# extract state dict for VAE
|
386 |
+
vae_state_dict = {}
|
387 |
+
vae_key = "first_stage_model."
|
388 |
+
keys = list(checkpoint.keys())
|
389 |
+
for key in keys:
|
390 |
+
if key.startswith(vae_key):
|
391 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
392 |
+
# if len(vae_state_dict) == 0:
|
393 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
394 |
+
# vae_state_dict = checkpoint
|
395 |
+
|
396 |
+
new_checkpoint = {}
|
397 |
+
|
398 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
399 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
400 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
401 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
402 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
403 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
404 |
+
|
405 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
406 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
407 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
408 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
409 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
410 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
411 |
+
|
412 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
413 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
414 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
415 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
416 |
+
|
417 |
+
# Retrieves the keys for the encoder down blocks only
|
418 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
419 |
+
down_blocks = {
|
420 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
421 |
+
}
|
422 |
+
|
423 |
+
# Retrieves the keys for the decoder up blocks only
|
424 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
425 |
+
up_blocks = {
|
426 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
427 |
+
}
|
428 |
+
|
429 |
+
for i in range(num_down_blocks):
|
430 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
431 |
+
|
432 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
433 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
434 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
435 |
+
)
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
438 |
+
)
|
439 |
+
|
440 |
+
paths = renew_vae_resnet_paths(resnets)
|
441 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
442 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
443 |
+
|
444 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
445 |
+
num_mid_res_blocks = 2
|
446 |
+
for i in range(1, num_mid_res_blocks + 1):
|
447 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
448 |
+
|
449 |
+
paths = renew_vae_resnet_paths(resnets)
|
450 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
451 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
452 |
+
|
453 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
454 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
455 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
456 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
457 |
+
conv_attn_to_linear(new_checkpoint)
|
458 |
+
|
459 |
+
for i in range(num_up_blocks):
|
460 |
+
block_id = num_up_blocks - 1 - i
|
461 |
+
resnets = [
|
462 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
463 |
+
]
|
464 |
+
|
465 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
466 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
467 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
468 |
+
]
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
471 |
+
]
|
472 |
+
|
473 |
+
paths = renew_vae_resnet_paths(resnets)
|
474 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
475 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
476 |
+
|
477 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
478 |
+
num_mid_res_blocks = 2
|
479 |
+
for i in range(1, num_mid_res_blocks + 1):
|
480 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
481 |
+
|
482 |
+
paths = renew_vae_resnet_paths(resnets)
|
483 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
484 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
485 |
+
|
486 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
487 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
488 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
489 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
490 |
+
conv_attn_to_linear(new_checkpoint)
|
491 |
+
return new_checkpoint
|
492 |
+
|
493 |
+
|
494 |
+
def create_unet_diffusers_config(v2):
|
495 |
+
"""
|
496 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
497 |
+
"""
|
498 |
+
# unet_params = original_config.model.params.unet_config.params
|
499 |
+
|
500 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
501 |
+
|
502 |
+
down_block_types = []
|
503 |
+
resolution = 1
|
504 |
+
for i in range(len(block_out_channels)):
|
505 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
506 |
+
down_block_types.append(block_type)
|
507 |
+
if i != len(block_out_channels) - 1:
|
508 |
+
resolution *= 2
|
509 |
+
|
510 |
+
up_block_types = []
|
511 |
+
for i in range(len(block_out_channels)):
|
512 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
513 |
+
up_block_types.append(block_type)
|
514 |
+
resolution //= 2
|
515 |
+
|
516 |
+
config = dict(
|
517 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
518 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
519 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
520 |
+
down_block_types=tuple(down_block_types),
|
521 |
+
up_block_types=tuple(up_block_types),
|
522 |
+
block_out_channels=tuple(block_out_channels),
|
523 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
524 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
525 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
526 |
+
)
|
527 |
+
|
528 |
+
return config
|
529 |
+
|
530 |
+
|
531 |
+
def create_vae_diffusers_config():
|
532 |
+
"""
|
533 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
534 |
+
"""
|
535 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
536 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
537 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
538 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
539 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
540 |
+
|
541 |
+
config = dict(
|
542 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
543 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
544 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
545 |
+
down_block_types=tuple(down_block_types),
|
546 |
+
up_block_types=tuple(up_block_types),
|
547 |
+
block_out_channels=tuple(block_out_channels),
|
548 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
549 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
550 |
+
)
|
551 |
+
return config
|
552 |
+
|
553 |
+
|
554 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
555 |
+
keys = list(checkpoint.keys())
|
556 |
+
text_model_dict = {}
|
557 |
+
for key in keys:
|
558 |
+
if key.startswith("cond_stage_model.transformer"):
|
559 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
560 |
+
return text_model_dict
|
561 |
+
|
562 |
+
|
563 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
564 |
+
# 嫌になるくらい違うぞ!
|
565 |
+
def convert_key(key):
|
566 |
+
if not key.startswith("cond_stage_model"):
|
567 |
+
return None
|
568 |
+
|
569 |
+
# common conversion
|
570 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
571 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
572 |
+
|
573 |
+
if "resblocks" in key:
|
574 |
+
# resblocks conversion
|
575 |
+
key = key.replace(".resblocks.", ".layers.")
|
576 |
+
if ".ln_" in key:
|
577 |
+
key = key.replace(".ln_", ".layer_norm")
|
578 |
+
elif ".mlp." in key:
|
579 |
+
key = key.replace(".c_fc.", ".fc1.")
|
580 |
+
key = key.replace(".c_proj.", ".fc2.")
|
581 |
+
elif '.attn.out_proj' in key:
|
582 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
583 |
+
elif '.attn.in_proj' in key:
|
584 |
+
key = None # 特殊なので後で処理する
|
585 |
+
else:
|
586 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
587 |
+
elif '.positional_embedding' in key:
|
588 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
589 |
+
elif '.text_projection' in key:
|
590 |
+
key = None # 使われない???
|
591 |
+
elif '.logit_scale' in key:
|
592 |
+
key = None # 使われない???
|
593 |
+
elif '.token_embedding' in key:
|
594 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
595 |
+
elif '.ln_final' in key:
|
596 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
597 |
+
return key
|
598 |
+
|
599 |
+
keys = list(checkpoint.keys())
|
600 |
+
new_sd = {}
|
601 |
+
for key in keys:
|
602 |
+
# remove resblocks 23
|
603 |
+
if '.resblocks.23.' in key:
|
604 |
+
continue
|
605 |
+
new_key = convert_key(key)
|
606 |
+
if new_key is None:
|
607 |
+
continue
|
608 |
+
new_sd[new_key] = checkpoint[key]
|
609 |
+
|
610 |
+
# attnの変換
|
611 |
+
for key in keys:
|
612 |
+
if '.resblocks.23.' in key:
|
613 |
+
continue
|
614 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
615 |
+
# 三つに分割
|
616 |
+
values = torch.chunk(checkpoint[key], 3)
|
617 |
+
|
618 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
619 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
620 |
+
key_pfx = key_pfx.replace("_weight", "")
|
621 |
+
key_pfx = key_pfx.replace("_bias", "")
|
622 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
623 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
624 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
625 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
626 |
+
|
627 |
+
# rename or add position_ids
|
628 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
629 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
630 |
+
# waifu diffusion v1.4
|
631 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
632 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
633 |
+
else:
|
634 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
635 |
+
|
636 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
637 |
+
return new_sd
|
638 |
+
|
639 |
+
# endregion
|
640 |
+
|
641 |
+
|
642 |
+
# region Diffusers->StableDiffusion の変換コード
|
643 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
644 |
+
|
645 |
+
def conv_transformer_to_linear(checkpoint):
|
646 |
+
keys = list(checkpoint.keys())
|
647 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
648 |
+
for key in keys:
|
649 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
650 |
+
if checkpoint[key].ndim > 2:
|
651 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
652 |
+
|
653 |
+
|
654 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
655 |
+
unet_conversion_map = [
|
656 |
+
# (stable-diffusion, HF Diffusers)
|
657 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
658 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
659 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
660 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
661 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
662 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
663 |
+
("out.0.weight", "conv_norm_out.weight"),
|
664 |
+
("out.0.bias", "conv_norm_out.bias"),
|
665 |
+
("out.2.weight", "conv_out.weight"),
|
666 |
+
("out.2.bias", "conv_out.bias"),
|
667 |
+
]
|
668 |
+
|
669 |
+
unet_conversion_map_resnet = [
|
670 |
+
# (stable-diffusion, HF Diffusers)
|
671 |
+
("in_layers.0", "norm1"),
|
672 |
+
("in_layers.2", "conv1"),
|
673 |
+
("out_layers.0", "norm2"),
|
674 |
+
("out_layers.3", "conv2"),
|
675 |
+
("emb_layers.1", "time_emb_proj"),
|
676 |
+
("skip_connection", "conv_shortcut"),
|
677 |
+
]
|
678 |
+
|
679 |
+
unet_conversion_map_layer = []
|
680 |
+
for i in range(4):
|
681 |
+
# loop over downblocks/upblocks
|
682 |
+
|
683 |
+
for j in range(2):
|
684 |
+
# loop over resnets/attentions for downblocks
|
685 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
686 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
687 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
688 |
+
|
689 |
+
if i < 3:
|
690 |
+
# no attention layers in down_blocks.3
|
691 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
692 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
693 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
694 |
+
|
695 |
+
for j in range(3):
|
696 |
+
# loop over resnets/attentions for upblocks
|
697 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
698 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
699 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
700 |
+
|
701 |
+
if i > 0:
|
702 |
+
# no attention layers in up_blocks.0
|
703 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
704 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
705 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
706 |
+
|
707 |
+
if i < 3:
|
708 |
+
# no downsample in down_blocks.3
|
709 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
710 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
711 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
712 |
+
|
713 |
+
# no upsample in up_blocks.3
|
714 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
715 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
716 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
717 |
+
|
718 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
719 |
+
sd_mid_atn_prefix = "middle_block.1."
|
720 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
721 |
+
|
722 |
+
for j in range(2):
|
723 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
724 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
725 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
726 |
+
|
727 |
+
# buyer beware: this is a *brittle* function,
|
728 |
+
# and correct output requires that all of these pieces interact in
|
729 |
+
# the exact order in which I have arranged them.
|
730 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
731 |
+
for sd_name, hf_name in unet_conversion_map:
|
732 |
+
mapping[hf_name] = sd_name
|
733 |
+
for k, v in mapping.items():
|
734 |
+
if "resnets" in k:
|
735 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
736 |
+
v = v.replace(hf_part, sd_part)
|
737 |
+
mapping[k] = v
|
738 |
+
for k, v in mapping.items():
|
739 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
740 |
+
v = v.replace(hf_part, sd_part)
|
741 |
+
mapping[k] = v
|
742 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
743 |
+
|
744 |
+
if v2:
|
745 |
+
conv_transformer_to_linear(new_state_dict)
|
746 |
+
|
747 |
+
return new_state_dict
|
748 |
+
|
749 |
+
|
750 |
+
# ================#
|
751 |
+
# VAE Conversion #
|
752 |
+
# ================#
|
753 |
+
|
754 |
+
def reshape_weight_for_sd(w):
|
755 |
+
# convert HF linear weights to SD conv2d weights
|
756 |
+
return w.reshape(*w.shape, 1, 1)
|
757 |
+
|
758 |
+
|
759 |
+
def convert_vae_state_dict(vae_state_dict):
|
760 |
+
vae_conversion_map = [
|
761 |
+
# (stable-diffusion, HF Diffusers)
|
762 |
+
("nin_shortcut", "conv_shortcut"),
|
763 |
+
("norm_out", "conv_norm_out"),
|
764 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
765 |
+
]
|
766 |
+
|
767 |
+
for i in range(4):
|
768 |
+
# down_blocks have two resnets
|
769 |
+
for j in range(2):
|
770 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
771 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
772 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
773 |
+
|
774 |
+
if i < 3:
|
775 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
776 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
777 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
778 |
+
|
779 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
780 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
781 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
782 |
+
|
783 |
+
# up_blocks have three resnets
|
784 |
+
# also, up blocks in hf are numbered in reverse from sd
|
785 |
+
for j in range(3):
|
786 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
787 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
788 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
789 |
+
|
790 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
791 |
+
for i in range(2):
|
792 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
793 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
794 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
795 |
+
|
796 |
+
vae_conversion_map_attn = [
|
797 |
+
# (stable-diffusion, HF Diffusers)
|
798 |
+
("norm.", "group_norm."),
|
799 |
+
("q.", "query."),
|
800 |
+
("k.", "key."),
|
801 |
+
("v.", "value."),
|
802 |
+
("proj_out.", "proj_attn."),
|
803 |
+
]
|
804 |
+
|
805 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
806 |
+
for k, v in mapping.items():
|
807 |
+
for sd_part, hf_part in vae_conversion_map:
|
808 |
+
v = v.replace(hf_part, sd_part)
|
809 |
+
mapping[k] = v
|
810 |
+
for k, v in mapping.items():
|
811 |
+
if "attentions" in k:
|
812 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
813 |
+
v = v.replace(hf_part, sd_part)
|
814 |
+
mapping[k] = v
|
815 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
816 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
817 |
+
for k, v in new_state_dict.items():
|
818 |
+
for weight_name in weights_to_convert:
|
819 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
820 |
+
# print(f"Reshaping {k} for SD format")
|
821 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
822 |
+
|
823 |
+
return new_state_dict
|
824 |
+
|
825 |
+
|
826 |
+
# endregion
|
827 |
+
|
828 |
+
# region 自作のモデル読み書きなど
|
829 |
+
|
830 |
+
def is_safetensors(path):
|
831 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
832 |
+
|
833 |
+
|
834 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
835 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
836 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
837 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
838 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
839 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
840 |
+
]
|
841 |
+
|
842 |
+
if is_safetensors(ckpt_path):
|
843 |
+
checkpoint = None
|
844 |
+
state_dict = load_file(ckpt_path, "cpu")
|
845 |
+
else:
|
846 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
847 |
+
if "state_dict" in checkpoint:
|
848 |
+
state_dict = checkpoint["state_dict"]
|
849 |
+
else:
|
850 |
+
state_dict = checkpoint
|
851 |
+
checkpoint = None
|
852 |
+
|
853 |
+
key_reps = []
|
854 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
855 |
+
for key in state_dict.keys():
|
856 |
+
if key.startswith(rep_from):
|
857 |
+
new_key = rep_to + key[len(rep_from):]
|
858 |
+
key_reps.append((key, new_key))
|
859 |
+
|
860 |
+
for key, new_key in key_reps:
|
861 |
+
state_dict[new_key] = state_dict[key]
|
862 |
+
del state_dict[key]
|
863 |
+
|
864 |
+
return checkpoint, state_dict
|
865 |
+
|
866 |
+
|
867 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
868 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
869 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
870 |
+
if dtype is not None:
|
871 |
+
for k, v in state_dict.items():
|
872 |
+
if type(v) is torch.Tensor:
|
873 |
+
state_dict[k] = v.to(dtype)
|
874 |
+
|
875 |
+
# Convert the UNet2DConditionModel model.
|
876 |
+
unet_config = create_unet_diffusers_config(v2)
|
877 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
878 |
+
|
879 |
+
unet = UNet2DConditionModel(**unet_config)
|
880 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
881 |
+
print("loading u-net:", info)
|
882 |
+
|
883 |
+
# Convert the VAE model.
|
884 |
+
vae_config = create_vae_diffusers_config()
|
885 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
886 |
+
|
887 |
+
vae = AutoencoderKL(**vae_config)
|
888 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
889 |
+
print("loading vae:", info)
|
890 |
+
|
891 |
+
# convert text_model
|
892 |
+
if v2:
|
893 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
894 |
+
cfg = CLIPTextConfig(
|
895 |
+
vocab_size=49408,
|
896 |
+
hidden_size=1024,
|
897 |
+
intermediate_size=4096,
|
898 |
+
num_hidden_layers=23,
|
899 |
+
num_attention_heads=16,
|
900 |
+
max_position_embeddings=77,
|
901 |
+
hidden_act="gelu",
|
902 |
+
layer_norm_eps=1e-05,
|
903 |
+
dropout=0.0,
|
904 |
+
attention_dropout=0.0,
|
905 |
+
initializer_range=0.02,
|
906 |
+
initializer_factor=1.0,
|
907 |
+
pad_token_id=1,
|
908 |
+
bos_token_id=0,
|
909 |
+
eos_token_id=2,
|
910 |
+
model_type="clip_text_model",
|
911 |
+
projection_dim=512,
|
912 |
+
torch_dtype="float32",
|
913 |
+
transformers_version="4.25.0.dev0",
|
914 |
+
)
|
915 |
+
text_model = CLIPTextModel._from_config(cfg)
|
916 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
+
else:
|
918 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
920 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
+
print("loading text encoder:", info)
|
922 |
+
|
923 |
+
return text_model, vae, unet
|
924 |
+
|
925 |
+
|
926 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
927 |
+
def convert_key(key):
|
928 |
+
# position_idsの除去
|
929 |
+
if ".position_ids" in key:
|
930 |
+
return None
|
931 |
+
|
932 |
+
# common
|
933 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
934 |
+
key = key.replace("text_model.", "")
|
935 |
+
if "layers" in key:
|
936 |
+
# resblocks conversion
|
937 |
+
key = key.replace(".layers.", ".resblocks.")
|
938 |
+
if ".layer_norm" in key:
|
939 |
+
key = key.replace(".layer_norm", ".ln_")
|
940 |
+
elif ".mlp." in key:
|
941 |
+
key = key.replace(".fc1.", ".c_fc.")
|
942 |
+
key = key.replace(".fc2.", ".c_proj.")
|
943 |
+
elif '.self_attn.out_proj' in key:
|
944 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
945 |
+
elif '.self_attn.' in key:
|
946 |
+
key = None # 特殊なので後で処理する
|
947 |
+
else:
|
948 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
949 |
+
elif '.position_embedding' in key:
|
950 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
951 |
+
elif '.token_embedding' in key:
|
952 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
953 |
+
elif 'final_layer_norm' in key:
|
954 |
+
key = key.replace("final_layer_norm", "ln_final")
|
955 |
+
return key
|
956 |
+
|
957 |
+
keys = list(checkpoint.keys())
|
958 |
+
new_sd = {}
|
959 |
+
for key in keys:
|
960 |
+
new_key = convert_key(key)
|
961 |
+
if new_key is None:
|
962 |
+
continue
|
963 |
+
new_sd[new_key] = checkpoint[key]
|
964 |
+
|
965 |
+
# attnの変換
|
966 |
+
for key in keys:
|
967 |
+
if 'layers' in key and 'q_proj' in key:
|
968 |
+
# 三つを結合
|
969 |
+
key_q = key
|
970 |
+
key_k = key.replace("q_proj", "k_proj")
|
971 |
+
key_v = key.replace("q_proj", "v_proj")
|
972 |
+
|
973 |
+
value_q = checkpoint[key_q]
|
974 |
+
value_k = checkpoint[key_k]
|
975 |
+
value_v = checkpoint[key_v]
|
976 |
+
value = torch.cat([value_q, value_k, value_v])
|
977 |
+
|
978 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
979 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
980 |
+
new_sd[new_key] = value
|
981 |
+
|
982 |
+
# 最後の層などを捏造するか
|
983 |
+
if make_dummy_weights:
|
984 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
985 |
+
keys = list(new_sd.keys())
|
986 |
+
for key in keys:
|
987 |
+
if key.startswith("transformer.resblocks.22."):
|
988 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
989 |
+
|
990 |
+
# Diffusersに含まれない重みを作っておく
|
991 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
992 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
993 |
+
|
994 |
+
return new_sd
|
995 |
+
|
996 |
+
|
997 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
998 |
+
if ckpt_path is not None:
|
999 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1000 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1001 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1002 |
+
checkpoint = {}
|
1003 |
+
strict = False
|
1004 |
+
else:
|
1005 |
+
strict = True
|
1006 |
+
if "state_dict" in state_dict:
|
1007 |
+
del state_dict["state_dict"]
|
1008 |
+
else:
|
1009 |
+
# 新しく作る
|
1010 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1011 |
+
checkpoint = {}
|
1012 |
+
state_dict = {}
|
1013 |
+
strict = False
|
1014 |
+
|
1015 |
+
def update_sd(prefix, sd):
|
1016 |
+
for k, v in sd.items():
|
1017 |
+
key = prefix + k
|
1018 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1019 |
+
if save_dtype is not None:
|
1020 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1021 |
+
state_dict[key] = v
|
1022 |
+
|
1023 |
+
# Convert the UNet model
|
1024 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1025 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1026 |
+
|
1027 |
+
# Convert the text encoder model
|
1028 |
+
if v2:
|
1029 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製��て作るなどダミーの重みを入れる
|
1030 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1031 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1032 |
+
else:
|
1033 |
+
text_enc_dict = text_encoder.state_dict()
|
1034 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1035 |
+
|
1036 |
+
# Convert the VAE
|
1037 |
+
if vae is not None:
|
1038 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1039 |
+
update_sd("first_stage_model.", vae_dict)
|
1040 |
+
|
1041 |
+
# Put together new checkpoint
|
1042 |
+
key_count = len(state_dict.keys())
|
1043 |
+
new_ckpt = {'state_dict': state_dict}
|
1044 |
+
|
1045 |
+
if 'epoch' in checkpoint:
|
1046 |
+
epochs += checkpoint['epoch']
|
1047 |
+
if 'global_step' in checkpoint:
|
1048 |
+
steps += checkpoint['global_step']
|
1049 |
+
|
1050 |
+
new_ckpt['epoch'] = epochs
|
1051 |
+
new_ckpt['global_step'] = steps
|
1052 |
+
|
1053 |
+
if is_safetensors(output_file):
|
1054 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1055 |
+
save_file(state_dict, output_file)
|
1056 |
+
else:
|
1057 |
+
torch.save(new_ckpt, output_file)
|
1058 |
+
|
1059 |
+
return key_count
|
1060 |
+
|
1061 |
+
|
1062 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1063 |
+
if pretrained_model_name_or_path is None:
|
1064 |
+
# load default settings for v1/v2
|
1065 |
+
if v2:
|
1066 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1067 |
+
else:
|
1068 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1069 |
+
|
1070 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1071 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1072 |
+
if vae is None:
|
1073 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1074 |
+
|
1075 |
+
pipeline = StableDiffusionPipeline(
|
1076 |
+
unet=unet,
|
1077 |
+
text_encoder=text_encoder,
|
1078 |
+
vae=vae,
|
1079 |
+
scheduler=scheduler,
|
1080 |
+
tokenizer=tokenizer,
|
1081 |
+
safety_checker=None,
|
1082 |
+
feature_extractor=None,
|
1083 |
+
requires_safety_checker=None,
|
1084 |
+
)
|
1085 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1086 |
+
|
1087 |
+
|
1088 |
+
VAE_PREFIX = "first_stage_model."
|
1089 |
+
|
1090 |
+
|
1091 |
+
def load_vae(vae_id, dtype):
|
1092 |
+
print(f"load VAE: {vae_id}")
|
1093 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1094 |
+
# Diffusers local/remote
|
1095 |
+
try:
|
1096 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1097 |
+
except EnvironmentError as e:
|
1098 |
+
print(f"exception occurs in loading vae: {e}")
|
1099 |
+
print("retry with subfolder='vae'")
|
1100 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1101 |
+
return vae
|
1102 |
+
|
1103 |
+
# local
|
1104 |
+
vae_config = create_vae_diffusers_config()
|
1105 |
+
|
1106 |
+
if vae_id.endswith(".bin"):
|
1107 |
+
# SD 1.5 VAE on Huggingface
|
1108 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1109 |
+
else:
|
1110 |
+
# StableDiffusion
|
1111 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1112 |
+
else torch.load(vae_id, map_location="cpu"))
|
1113 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1114 |
+
|
1115 |
+
# vae only or full model
|
1116 |
+
full_model = False
|
1117 |
+
for vae_key in vae_sd:
|
1118 |
+
if vae_key.startswith(VAE_PREFIX):
|
1119 |
+
full_model = True
|
1120 |
+
break
|
1121 |
+
if not full_model:
|
1122 |
+
sd = {}
|
1123 |
+
for key, value in vae_sd.items():
|
1124 |
+
sd[VAE_PREFIX + key] = value
|
1125 |
+
vae_sd = sd
|
1126 |
+
del sd
|
1127 |
+
|
1128 |
+
# Convert the VAE model.
|
1129 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1130 |
+
|
1131 |
+
vae = AutoencoderKL(**vae_config)
|
1132 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1133 |
+
return vae
|
1134 |
+
|
1135 |
+
# endregion
|
1136 |
+
|
1137 |
+
|
1138 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1139 |
+
max_width, max_height = max_reso
|
1140 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1141 |
+
|
1142 |
+
resos = set()
|
1143 |
+
|
1144 |
+
size = int(math.sqrt(max_area)) * divisible
|
1145 |
+
resos.add((size, size))
|
1146 |
+
|
1147 |
+
size = min_size
|
1148 |
+
while size <= max_size:
|
1149 |
+
width = size
|
1150 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1151 |
+
resos.add((width, height))
|
1152 |
+
resos.add((height, width))
|
1153 |
+
|
1154 |
+
# # make additional resos
|
1155 |
+
# if width >= height and width - divisible >= min_size:
|
1156 |
+
# resos.add((width - divisible, height))
|
1157 |
+
# resos.add((height, width - divisible))
|
1158 |
+
# if height >= width and height - divisible >= min_size:
|
1159 |
+
# resos.add((width, height - divisible))
|
1160 |
+
# resos.add((height - divisible, width))
|
1161 |
+
|
1162 |
+
size += divisible
|
1163 |
+
|
1164 |
+
resos = list(resos)
|
1165 |
+
resos.sort()
|
1166 |
+
return resos
|
1167 |
+
|
1168 |
+
|
1169 |
+
if __name__ == '__main__':
|
1170 |
+
resos = make_bucket_resolutions((512, 768))
|
1171 |
+
print(len(resos))
|
1172 |
+
print(resos)
|
1173 |
+
aspect_ratios = [w / h for w, h in resos]
|
1174 |
+
print(aspect_ratios)
|
1175 |
+
|
1176 |
+
ars = set()
|
1177 |
+
for ar in aspect_ratios:
|
1178 |
+
if ar in ars:
|
1179 |
+
print("error! duplicate ar:", ar)
|
1180 |
+
ars.add(ar)
|
library/train_util.py
ADDED
@@ -0,0 +1,1796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# common functions for training
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
+
import glob
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import hashlib
|
15 |
+
import subprocess
|
16 |
+
from io import BytesIO
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
import torch
|
20 |
+
from torchvision import transforms
|
21 |
+
from transformers import CLIPTokenizer
|
22 |
+
import diffusers
|
23 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
24 |
+
import albumentations as albu
|
25 |
+
import numpy as np
|
26 |
+
from PIL import Image
|
27 |
+
import cv2
|
28 |
+
from einops import rearrange
|
29 |
+
from torch import einsum
|
30 |
+
import safetensors.torch
|
31 |
+
|
32 |
+
import library.model_util as model_util
|
33 |
+
|
34 |
+
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
35 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
36 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
37 |
+
|
38 |
+
# checkpointファイル名
|
39 |
+
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
40 |
+
EPOCH_FILE_NAME = "{}-{:06d}"
|
41 |
+
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
|
42 |
+
LAST_STATE_NAME = "{}-state"
|
43 |
+
DEFAULT_EPOCH_NAME = "epoch"
|
44 |
+
DEFAULT_LAST_OUTPUT_NAME = "last"
|
45 |
+
|
46 |
+
# region dataset
|
47 |
+
|
48 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
49 |
+
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
50 |
+
|
51 |
+
|
52 |
+
class ImageInfo():
|
53 |
+
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
54 |
+
self.image_key: str = image_key
|
55 |
+
self.num_repeats: int = num_repeats
|
56 |
+
self.caption: str = caption
|
57 |
+
self.is_reg: bool = is_reg
|
58 |
+
self.absolute_path: str = absolute_path
|
59 |
+
self.image_size: Tuple[int, int] = None
|
60 |
+
self.resized_size: Tuple[int, int] = None
|
61 |
+
self.bucket_reso: Tuple[int, int] = None
|
62 |
+
self.latents: torch.Tensor = None
|
63 |
+
self.latents_flipped: torch.Tensor = None
|
64 |
+
self.latents_npz: str = None
|
65 |
+
self.latents_npz_flipped: str = None
|
66 |
+
|
67 |
+
|
68 |
+
class BucketManager():
|
69 |
+
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
70 |
+
self.no_upscale = no_upscale
|
71 |
+
if max_reso is None:
|
72 |
+
self.max_reso = None
|
73 |
+
self.max_area = None
|
74 |
+
else:
|
75 |
+
self.max_reso = max_reso
|
76 |
+
self.max_area = max_reso[0] * max_reso[1]
|
77 |
+
self.min_size = min_size
|
78 |
+
self.max_size = max_size
|
79 |
+
self.reso_steps = reso_steps
|
80 |
+
|
81 |
+
self.resos = []
|
82 |
+
self.reso_to_id = {}
|
83 |
+
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
84 |
+
|
85 |
+
def add_image(self, reso, image):
|
86 |
+
bucket_id = self.reso_to_id[reso]
|
87 |
+
self.buckets[bucket_id].append(image)
|
88 |
+
|
89 |
+
def shuffle(self):
|
90 |
+
for bucket in self.buckets:
|
91 |
+
random.shuffle(bucket)
|
92 |
+
|
93 |
+
def sort(self):
|
94 |
+
# 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す
|
95 |
+
sorted_resos = self.resos.copy()
|
96 |
+
sorted_resos.sort()
|
97 |
+
|
98 |
+
sorted_buckets = []
|
99 |
+
sorted_reso_to_id = {}
|
100 |
+
for i, reso in enumerate(sorted_resos):
|
101 |
+
bucket_id = self.reso_to_id[reso]
|
102 |
+
sorted_buckets.append(self.buckets[bucket_id])
|
103 |
+
sorted_reso_to_id[reso] = i
|
104 |
+
|
105 |
+
self.resos = sorted_resos
|
106 |
+
self.buckets = sorted_buckets
|
107 |
+
self.reso_to_id = sorted_reso_to_id
|
108 |
+
|
109 |
+
def make_buckets(self):
|
110 |
+
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
111 |
+
self.set_predefined_resos(resos)
|
112 |
+
|
113 |
+
def set_predefined_resos(self, resos):
|
114 |
+
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
115 |
+
self.predefined_resos = resos.copy()
|
116 |
+
self.predefined_resos_set = set(resos)
|
117 |
+
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
118 |
+
|
119 |
+
def add_if_new_reso(self, reso):
|
120 |
+
if reso not in self.reso_to_id:
|
121 |
+
bucket_id = len(self.resos)
|
122 |
+
self.reso_to_id[reso] = bucket_id
|
123 |
+
self.resos.append(reso)
|
124 |
+
self.buckets.append([])
|
125 |
+
# print(reso, bucket_id, len(self.buckets))
|
126 |
+
|
127 |
+
def round_to_steps(self, x):
|
128 |
+
x = int(x + .5)
|
129 |
+
return x - x % self.reso_steps
|
130 |
+
|
131 |
+
def select_bucket(self, image_width, image_height):
|
132 |
+
aspect_ratio = image_width / image_height
|
133 |
+
if not self.no_upscale:
|
134 |
+
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
135 |
+
reso = (image_width, image_height)
|
136 |
+
if reso in self.predefined_resos_set:
|
137 |
+
pass
|
138 |
+
else:
|
139 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
140 |
+
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
141 |
+
reso = self.predefined_resos[predefined_bucket_id]
|
142 |
+
|
143 |
+
ar_reso = reso[0] / reso[1]
|
144 |
+
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
145 |
+
scale = reso[1] / image_height
|
146 |
+
else:
|
147 |
+
scale = reso[0] / image_width
|
148 |
+
|
149 |
+
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
150 |
+
# print("use predef", image_width, image_height, reso, resized_size)
|
151 |
+
else:
|
152 |
+
if image_width * image_height > self.max_area:
|
153 |
+
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
154 |
+
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
155 |
+
resized_height = self.max_area / resized_width
|
156 |
+
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
157 |
+
|
158 |
+
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
159 |
+
# 元のbucketingと同じロジック
|
160 |
+
b_width_rounded = self.round_to_steps(resized_width)
|
161 |
+
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
162 |
+
ar_width_rounded = b_width_rounded / b_height_in_wr
|
163 |
+
|
164 |
+
b_height_rounded = self.round_to_steps(resized_height)
|
165 |
+
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
166 |
+
ar_height_rounded = b_width_in_hr / b_height_rounded
|
167 |
+
|
168 |
+
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
169 |
+
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
170 |
+
|
171 |
+
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
172 |
+
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
173 |
+
else:
|
174 |
+
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
175 |
+
# print(resized_size)
|
176 |
+
else:
|
177 |
+
resized_size = (image_width, image_height) # リサイズは不要
|
178 |
+
|
179 |
+
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
180 |
+
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
181 |
+
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
182 |
+
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
183 |
+
|
184 |
+
reso = (bucket_width, bucket_height)
|
185 |
+
|
186 |
+
self.add_if_new_reso(reso)
|
187 |
+
|
188 |
+
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
189 |
+
return reso, resized_size, ar_error
|
190 |
+
|
191 |
+
|
192 |
+
class BucketBatchIndex(NamedTuple):
|
193 |
+
bucket_index: int
|
194 |
+
bucket_batch_size: int
|
195 |
+
batch_index: int
|
196 |
+
|
197 |
+
|
198 |
+
class BaseDataset(torch.utils.data.Dataset):
|
199 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
200 |
+
super().__init__()
|
201 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
202 |
+
self.max_token_length = max_token_length
|
203 |
+
self.shuffle_caption = shuffle_caption
|
204 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
+
# width/height is used when enable_bucket==False
|
206 |
+
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
+
self.face_crop_aug_range = face_crop_aug_range
|
208 |
+
self.flip_aug = flip_aug
|
209 |
+
self.color_aug = color_aug
|
210 |
+
self.debug_dataset = debug_dataset
|
211 |
+
self.random_crop = random_crop
|
212 |
+
self.token_padding_disabled = False
|
213 |
+
self.dataset_dirs_info = {}
|
214 |
+
self.reg_dataset_dirs_info = {}
|
215 |
+
self.tag_frequency = {}
|
216 |
+
|
217 |
+
self.enable_bucket = False
|
218 |
+
self.bucket_manager: BucketManager = None # not initialized
|
219 |
+
self.min_bucket_reso = None
|
220 |
+
self.max_bucket_reso = None
|
221 |
+
self.bucket_reso_steps = None
|
222 |
+
self.bucket_no_upscale = None
|
223 |
+
self.bucket_info = None # for metadata
|
224 |
+
|
225 |
+
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
+
|
227 |
+
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
+
self.dropout_rate: float = 0
|
229 |
+
self.dropout_every_n_epochs: int = None
|
230 |
+
self.tag_dropout_rate: float = 0
|
231 |
+
|
232 |
+
# augmentation
|
233 |
+
flip_p = 0.5 if flip_aug else 0.0
|
234 |
+
if color_aug:
|
235 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
+
self.aug = albu.Compose([
|
237 |
+
albu.OneOf([
|
238 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
+
albu.RandomGamma((95, 105), p=.5),
|
240 |
+
], p=.33),
|
241 |
+
albu.HorizontalFlip(p=flip_p)
|
242 |
+
], p=1.)
|
243 |
+
elif flip_aug:
|
244 |
+
self.aug = albu.Compose([
|
245 |
+
albu.HorizontalFlip(p=flip_p)
|
246 |
+
], p=1.)
|
247 |
+
else:
|
248 |
+
self.aug = None
|
249 |
+
|
250 |
+
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
+
|
252 |
+
self.image_data: Dict[str, ImageInfo] = {}
|
253 |
+
|
254 |
+
self.replacements = {}
|
255 |
+
|
256 |
+
def set_current_epoch(self, epoch):
|
257 |
+
self.current_epoch = epoch
|
258 |
+
|
259 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
+
self.dropout_rate = dropout_rate
|
262 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
+
self.tag_dropout_rate = tag_dropout_rate
|
264 |
+
|
265 |
+
def set_tag_frequency(self, dir_name, captions):
|
266 |
+
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
+
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
+
for caption in captions:
|
269 |
+
for tag in caption.split(","):
|
270 |
+
if tag and not tag.isspace():
|
271 |
+
tag = tag.lower()
|
272 |
+
frequency = frequency_for_dir.get(tag, 0)
|
273 |
+
frequency_for_dir[tag] = frequency + 1
|
274 |
+
|
275 |
+
def disable_token_padding(self):
|
276 |
+
self.token_padding_disabled = True
|
277 |
+
|
278 |
+
def add_replacement(self, str_from, str_to):
|
279 |
+
self.replacements[str_from] = str_to
|
280 |
+
|
281 |
+
def process_caption(self, caption):
|
282 |
+
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
284 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
285 |
+
|
286 |
+
if is_drop_out:
|
287 |
+
caption = ""
|
288 |
+
else:
|
289 |
+
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
290 |
+
def dropout_tags(tokens):
|
291 |
+
if self.tag_dropout_rate <= 0:
|
292 |
+
return tokens
|
293 |
+
l = []
|
294 |
+
for token in tokens:
|
295 |
+
if random.random() >= self.tag_dropout_rate:
|
296 |
+
l.append(token)
|
297 |
+
return l
|
298 |
+
|
299 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
300 |
+
if self.shuffle_keep_tokens is None:
|
301 |
+
if self.shuffle_caption:
|
302 |
+
random.shuffle(tokens)
|
303 |
+
|
304 |
+
tokens = dropout_tags(tokens)
|
305 |
+
else:
|
306 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
+
|
310 |
+
if self.shuffle_caption:
|
311 |
+
random.shuffle(tokens)
|
312 |
+
|
313 |
+
tokens = dropout_tags(tokens)
|
314 |
+
|
315 |
+
tokens = keep_tokens + tokens
|
316 |
+
caption = ", ".join(tokens)
|
317 |
+
|
318 |
+
# textual inversion対応
|
319 |
+
for str_from, str_to in self.replacements.items():
|
320 |
+
if str_from == "":
|
321 |
+
# replace all
|
322 |
+
if type(str_to) == list:
|
323 |
+
caption = random.choice(str_to)
|
324 |
+
else:
|
325 |
+
caption = str_to
|
326 |
+
else:
|
327 |
+
caption = caption.replace(str_from, str_to)
|
328 |
+
|
329 |
+
return caption
|
330 |
+
|
331 |
+
def get_input_ids(self, caption):
|
332 |
+
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
333 |
+
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
334 |
+
|
335 |
+
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
336 |
+
input_ids = input_ids.squeeze(0)
|
337 |
+
iids_list = []
|
338 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
339 |
+
# v1
|
340 |
+
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
341 |
+
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
342 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
343 |
+
ids_chunk = (input_ids[0].unsqueeze(0),
|
344 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
345 |
+
input_ids[-1].unsqueeze(0))
|
346 |
+
ids_chunk = torch.cat(ids_chunk)
|
347 |
+
iids_list.append(ids_chunk)
|
348 |
+
else:
|
349 |
+
# v2
|
350 |
+
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
351 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
352 |
+
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
353 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
354 |
+
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
355 |
+
ids_chunk = torch.cat(ids_chunk)
|
356 |
+
|
357 |
+
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
358 |
+
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
359 |
+
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
360 |
+
ids_chunk[-1] = self.tokenizer.eos_token_id
|
361 |
+
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
362 |
+
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
363 |
+
ids_chunk[1] = self.tokenizer.eos_token_id
|
364 |
+
|
365 |
+
iids_list.append(ids_chunk)
|
366 |
+
|
367 |
+
input_ids = torch.stack(iids_list) # 3,77
|
368 |
+
return input_ids
|
369 |
+
|
370 |
+
def register_image(self, info: ImageInfo):
|
371 |
+
self.image_data[info.image_key] = info
|
372 |
+
|
373 |
+
def make_buckets(self):
|
374 |
+
'''
|
375 |
+
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
376 |
+
min_size and max_size are ignored when enable_bucket is False
|
377 |
+
'''
|
378 |
+
print("loading image sizes.")
|
379 |
+
for info in tqdm(self.image_data.values()):
|
380 |
+
if info.image_size is None:
|
381 |
+
info.image_size = self.get_image_size(info.absolute_path)
|
382 |
+
|
383 |
+
if self.enable_bucket:
|
384 |
+
print("make buckets")
|
385 |
+
else:
|
386 |
+
print("prepare dataset")
|
387 |
+
|
388 |
+
# bucketを作成し、画像をbucketに振り分ける
|
389 |
+
if self.enable_bucket:
|
390 |
+
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
391 |
+
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
392 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
393 |
+
if not self.bucket_no_upscale:
|
394 |
+
self.bucket_manager.make_buckets()
|
395 |
+
else:
|
396 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
397 |
+
|
398 |
+
img_ar_errors = []
|
399 |
+
for image_info in self.image_data.values():
|
400 |
+
image_width, image_height = image_info.image_size
|
401 |
+
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
402 |
+
|
403 |
+
# print(image_info.image_key, image_info.bucket_reso)
|
404 |
+
img_ar_errors.append(abs(ar_error))
|
405 |
+
|
406 |
+
self.bucket_manager.sort()
|
407 |
+
else:
|
408 |
+
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
409 |
+
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
410 |
+
for image_info in self.image_data.values():
|
411 |
+
image_width, image_height = image_info.image_size
|
412 |
+
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
413 |
+
|
414 |
+
for image_info in self.image_data.values():
|
415 |
+
for _ in range(image_info.num_repeats):
|
416 |
+
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
417 |
+
|
418 |
+
# bucket情報を表示、格納する
|
419 |
+
if self.enable_bucket:
|
420 |
+
self.bucket_info = {"buckets": {}}
|
421 |
+
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
422 |
+
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
423 |
+
count = len(bucket)
|
424 |
+
if count > 0:
|
425 |
+
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
426 |
+
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
427 |
+
|
428 |
+
img_ar_errors = np.array(img_ar_errors)
|
429 |
+
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
430 |
+
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
431 |
+
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
432 |
+
|
433 |
+
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
434 |
+
self.buckets_indices: List(BucketBatchIndex) = []
|
435 |
+
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
436 |
+
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
437 |
+
for batch_index in range(batch_count):
|
438 |
+
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
439 |
+
|
440 |
+
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
441 |
+
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
442 |
+
#
|
443 |
+
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
444 |
+
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
445 |
+
# # そのためバッチサイズを画像種類までに制限する
|
446 |
+
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
447 |
+
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
448 |
+
# num_of_image_types = len(set(bucket))
|
449 |
+
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
450 |
+
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
451 |
+
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
452 |
+
# for batch_index in range(batch_count):
|
453 |
+
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
454 |
+
# ↑ここまで
|
455 |
+
|
456 |
+
self.shuffle_buckets()
|
457 |
+
self._length = len(self.buckets_indices)
|
458 |
+
|
459 |
+
def shuffle_buckets(self):
|
460 |
+
random.shuffle(self.buckets_indices)
|
461 |
+
self.bucket_manager.shuffle()
|
462 |
+
|
463 |
+
def load_image(self, image_path):
|
464 |
+
image = Image.open(image_path)
|
465 |
+
if not image.mode == "RGB":
|
466 |
+
image = image.convert("RGB")
|
467 |
+
img = np.array(image, np.uint8)
|
468 |
+
return img
|
469 |
+
|
470 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
+
image_height, image_width = image.shape[0:2]
|
472 |
+
|
473 |
+
if image_width != resized_size[0] or image_height != resized_size[1]:
|
474 |
+
# リサイズする
|
475 |
+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
476 |
+
|
477 |
+
image_height, image_width = image.shape[0:2]
|
478 |
+
if image_width > reso[0]:
|
479 |
+
trim_size = image_width - reso[0]
|
480 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
481 |
+
# print("w", trim_size, p)
|
482 |
+
image = image[:, p:p + reso[0]]
|
483 |
+
if image_height > reso[1]:
|
484 |
+
trim_size = image_height - reso[1]
|
485 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
486 |
+
# print("h", trim_size, p)
|
487 |
+
image = image[p:p + reso[1]]
|
488 |
+
|
489 |
+
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
+
return image
|
491 |
+
|
492 |
+
def cache_latents(self, vae):
|
493 |
+
# TODO ここを高速化したい
|
494 |
+
print("caching latents.")
|
495 |
+
for info in tqdm(self.image_data.values()):
|
496 |
+
if info.latents_npz is not None:
|
497 |
+
info.latents = self.load_latents_from_npz(info, False)
|
498 |
+
info.latents = torch.FloatTensor(info.latents)
|
499 |
+
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
500 |
+
if info.latents_flipped is not None:
|
501 |
+
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
502 |
+
continue
|
503 |
+
|
504 |
+
image = self.load_image(info.absolute_path)
|
505 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
+
|
507 |
+
img_tensor = self.image_transforms(image)
|
508 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
+
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
+
|
511 |
+
if self.flip_aug:
|
512 |
+
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
+
img_tensor = self.image_transforms(image)
|
514 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
515 |
+
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
516 |
+
|
517 |
+
def get_image_size(self, image_path):
|
518 |
+
image = Image.open(image_path)
|
519 |
+
return image.size
|
520 |
+
|
521 |
+
def load_image_with_face_info(self, image_path: str):
|
522 |
+
img = self.load_image(image_path)
|
523 |
+
|
524 |
+
face_cx = face_cy = face_w = face_h = 0
|
525 |
+
if self.face_crop_aug_range is not None:
|
526 |
+
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
+
if len(tokens) >= 5:
|
528 |
+
face_cx = int(tokens[-4])
|
529 |
+
face_cy = int(tokens[-3])
|
530 |
+
face_w = int(tokens[-2])
|
531 |
+
face_h = int(tokens[-1])
|
532 |
+
|
533 |
+
return img, face_cx, face_cy, face_w, face_h
|
534 |
+
|
535 |
+
# いい感じに切り出す
|
536 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
+
height, width = image.shape[0:2]
|
538 |
+
if height == self.height and width == self.width:
|
539 |
+
return image
|
540 |
+
|
541 |
+
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
+
face_size = max(face_w, face_h)
|
543 |
+
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
545 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
546 |
+
if min_scale >= max_scale: # range指定がmin==max
|
547 |
+
scale = min_scale
|
548 |
+
else:
|
549 |
+
scale = random.uniform(min_scale, max_scale)
|
550 |
+
|
551 |
+
nh = int(height * scale + .5)
|
552 |
+
nw = int(width * scale + .5)
|
553 |
+
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
554 |
+
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
555 |
+
face_cx = int(face_cx * scale + .5)
|
556 |
+
face_cy = int(face_cy * scale + .5)
|
557 |
+
height, width = nh, nw
|
558 |
+
|
559 |
+
# 顔を中心として448*640とかへ切り出す
|
560 |
+
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
+
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
+
|
563 |
+
if self.random_crop:
|
564 |
+
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
+
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
+
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
+
else:
|
568 |
+
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
570 |
+
if face_size > self.size // 10 and face_size >= 40:
|
571 |
+
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
+
|
573 |
+
p1 = max(0, min(p1, length - target_size))
|
574 |
+
|
575 |
+
if axis == 0:
|
576 |
+
image = image[p1:p1 + target_size, :]
|
577 |
+
else:
|
578 |
+
image = image[:, p1:p1 + target_size]
|
579 |
+
|
580 |
+
return image
|
581 |
+
|
582 |
+
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
583 |
+
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
584 |
+
if npz_file is None:
|
585 |
+
return None
|
586 |
+
return np.load(npz_file)['arr_0']
|
587 |
+
|
588 |
+
def __len__(self):
|
589 |
+
return self._length
|
590 |
+
|
591 |
+
def __getitem__(self, index):
|
592 |
+
if index == 0:
|
593 |
+
self.shuffle_buckets()
|
594 |
+
|
595 |
+
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
+
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
+
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
598 |
+
|
599 |
+
loss_weights = []
|
600 |
+
captions = []
|
601 |
+
input_ids_list = []
|
602 |
+
latents_list = []
|
603 |
+
images = []
|
604 |
+
|
605 |
+
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
+
image_info = self.image_data[image_key]
|
607 |
+
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
+
|
609 |
+
# image/latentsを処理する
|
610 |
+
if image_info.latents is not None:
|
611 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
612 |
+
image = None
|
613 |
+
elif image_info.latents_npz is not None:
|
614 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
615 |
+
latents = torch.FloatTensor(latents)
|
616 |
+
image = None
|
617 |
+
else:
|
618 |
+
# 画像を読み込み、必要ならcropする
|
619 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
+
im_h, im_w = img.shape[0:2]
|
621 |
+
|
622 |
+
if self.enable_bucket:
|
623 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
+
else:
|
625 |
+
if face_cx > 0: # 顔位置情報あり
|
626 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
+
elif im_h > self.height or im_w > self.width:
|
628 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
629 |
+
if im_h > self.height:
|
630 |
+
p = random.randint(0, im_h - self.height)
|
631 |
+
img = img[p:p + self.height]
|
632 |
+
if im_w > self.width:
|
633 |
+
p = random.randint(0, im_w - self.width)
|
634 |
+
img = img[:, p:p + self.width]
|
635 |
+
|
636 |
+
im_h, im_w = img.shape[0:2]
|
637 |
+
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
+
|
639 |
+
# augmentation
|
640 |
+
if self.aug is not None:
|
641 |
+
img = self.aug(image=img)['image']
|
642 |
+
|
643 |
+
latents = None
|
644 |
+
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
645 |
+
|
646 |
+
images.append(image)
|
647 |
+
latents_list.append(latents)
|
648 |
+
|
649 |
+
caption = self.process_caption(image_info.caption)
|
650 |
+
captions.append(caption)
|
651 |
+
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
+
input_ids_list.append(self.get_input_ids(caption))
|
653 |
+
|
654 |
+
example = {}
|
655 |
+
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
656 |
+
|
657 |
+
if self.token_padding_disabled:
|
658 |
+
# padding=True means pad in the batch
|
659 |
+
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
660 |
+
else:
|
661 |
+
# batch processing seems to be good
|
662 |
+
example['input_ids'] = torch.stack(input_ids_list)
|
663 |
+
|
664 |
+
if images[0] is not None:
|
665 |
+
images = torch.stack(images)
|
666 |
+
images = images.to(memory_format=torch.contiguous_format).float()
|
667 |
+
else:
|
668 |
+
images = None
|
669 |
+
example['images'] = images
|
670 |
+
|
671 |
+
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
672 |
+
|
673 |
+
if self.debug_dataset:
|
674 |
+
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
675 |
+
example['captions'] = captions
|
676 |
+
return example
|
677 |
+
|
678 |
+
|
679 |
+
class DreamBoothDataset(BaseDataset):
|
680 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
681 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
682 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
+
|
684 |
+
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
+
|
686 |
+
self.batch_size = batch_size
|
687 |
+
self.size = min(self.width, self.height) # 短いほう
|
688 |
+
self.prior_loss_weight = prior_loss_weight
|
689 |
+
self.latents_cache = None
|
690 |
+
|
691 |
+
self.enable_bucket = enable_bucket
|
692 |
+
if self.enable_bucket:
|
693 |
+
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
694 |
+
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
695 |
+
self.min_bucket_reso = min_bucket_reso
|
696 |
+
self.max_bucket_reso = max_bucket_reso
|
697 |
+
self.bucket_reso_steps = bucket_reso_steps
|
698 |
+
self.bucket_no_upscale = bucket_no_upscale
|
699 |
+
else:
|
700 |
+
self.min_bucket_reso = None
|
701 |
+
self.max_bucket_reso = None
|
702 |
+
self.bucket_reso_steps = None # この情報は使われない
|
703 |
+
self.bucket_no_upscale = False
|
704 |
+
|
705 |
+
def read_caption(img_path):
|
706 |
+
# captionの候補ファイル名を作る
|
707 |
+
base_name = os.path.splitext(img_path)[0]
|
708 |
+
base_name_face_det = base_name
|
709 |
+
tokens = base_name.split("_")
|
710 |
+
if len(tokens) >= 5:
|
711 |
+
base_name_face_det = "_".join(tokens[:-4])
|
712 |
+
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
713 |
+
|
714 |
+
caption = None
|
715 |
+
for cap_path in cap_paths:
|
716 |
+
if os.path.isfile(cap_path):
|
717 |
+
with open(cap_path, "rt", encoding='utf-8') as f:
|
718 |
+
try:
|
719 |
+
lines = f.readlines()
|
720 |
+
except UnicodeDecodeError as e:
|
721 |
+
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
722 |
+
raise e
|
723 |
+
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
724 |
+
caption = lines[0].strip()
|
725 |
+
break
|
726 |
+
return caption
|
727 |
+
|
728 |
+
def load_dreambooth_dir(dir):
|
729 |
+
if not os.path.isdir(dir):
|
730 |
+
# print(f"ignore file: {dir}")
|
731 |
+
return 0, [], []
|
732 |
+
|
733 |
+
tokens = os.path.basename(dir).split('_')
|
734 |
+
try:
|
735 |
+
n_repeats = int(tokens[0])
|
736 |
+
except ValueError as e:
|
737 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
+
return 0, [], []
|
739 |
+
|
740 |
+
caption_by_folder = '_'.join(tokens[1:])
|
741 |
+
img_paths = glob_images(dir, "*")
|
742 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
+
|
744 |
+
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
+
captions = []
|
746 |
+
for img_path in img_paths:
|
747 |
+
cap_for_img = read_caption(img_path)
|
748 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
749 |
+
|
750 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
751 |
+
|
752 |
+
return n_repeats, img_paths, captions
|
753 |
+
|
754 |
+
print("prepare train images.")
|
755 |
+
train_dirs = os.listdir(train_data_dir)
|
756 |
+
num_train_images = 0
|
757 |
+
for dir in train_dirs:
|
758 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
759 |
+
num_train_images += n_repeats * len(img_paths)
|
760 |
+
|
761 |
+
for img_path, caption in zip(img_paths, captions):
|
762 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
763 |
+
self.register_image(info)
|
764 |
+
|
765 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
766 |
+
|
767 |
+
print(f"{num_train_images} train images with repeating.")
|
768 |
+
self.num_train_images = num_train_images
|
769 |
+
|
770 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
771 |
+
num_reg_images = 0
|
772 |
+
if reg_data_dir:
|
773 |
+
print("prepare reg images.")
|
774 |
+
reg_infos: List[ImageInfo] = []
|
775 |
+
|
776 |
+
reg_dirs = os.listdir(reg_data_dir)
|
777 |
+
for dir in reg_dirs:
|
778 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
779 |
+
num_reg_images += n_repeats * len(img_paths)
|
780 |
+
|
781 |
+
for img_path, caption in zip(img_paths, captions):
|
782 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
+
reg_infos.append(info)
|
784 |
+
|
785 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
+
|
787 |
+
print(f"{num_reg_images} reg images.")
|
788 |
+
if num_train_images < num_reg_images:
|
789 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
790 |
+
|
791 |
+
if num_reg_images == 0:
|
792 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
793 |
+
else:
|
794 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
795 |
+
n = 0
|
796 |
+
first_loop = True
|
797 |
+
while n < num_train_images:
|
798 |
+
for info in reg_infos:
|
799 |
+
if first_loop:
|
800 |
+
self.register_image(info)
|
801 |
+
n += info.num_repeats
|
802 |
+
else:
|
803 |
+
info.num_repeats += 1
|
804 |
+
n += 1
|
805 |
+
if n >= num_train_images:
|
806 |
+
break
|
807 |
+
first_loop = False
|
808 |
+
|
809 |
+
self.num_reg_images = num_reg_images
|
810 |
+
|
811 |
+
|
812 |
+
class FineTuningDataset(BaseDataset):
|
813 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
814 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
815 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
816 |
+
|
817 |
+
# メタデータを読み込む
|
818 |
+
if os.path.exists(json_file_name):
|
819 |
+
print(f"loading existing metadata: {json_file_name}")
|
820 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
+
metadata = json.load(f)
|
822 |
+
else:
|
823 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
+
|
825 |
+
self.metadata = metadata
|
826 |
+
self.train_data_dir = train_data_dir
|
827 |
+
self.batch_size = batch_size
|
828 |
+
|
829 |
+
tags_list = []
|
830 |
+
for image_key, img_md in metadata.items():
|
831 |
+
# path情報を作る
|
832 |
+
if os.path.exists(image_key):
|
833 |
+
abs_path = image_key
|
834 |
+
else:
|
835 |
+
# わりといい加減だがいい方法が思いつかん
|
836 |
+
abs_path = glob_images(train_data_dir, image_key)
|
837 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
+
abs_path = abs_path[0]
|
839 |
+
|
840 |
+
caption = img_md.get('caption')
|
841 |
+
tags = img_md.get('tags')
|
842 |
+
if caption is None:
|
843 |
+
caption = tags
|
844 |
+
elif tags is not None and len(tags) > 0:
|
845 |
+
caption = caption + ', ' + tags
|
846 |
+
tags_list.append(tags)
|
847 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
+
|
849 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
+
image_info.image_size = img_md.get('train_resolution')
|
851 |
+
|
852 |
+
if not self.color_aug and not self.random_crop:
|
853 |
+
# if npz exists, use them
|
854 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
+
|
856 |
+
self.register_image(image_info)
|
857 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
+
self.num_reg_images = 0
|
859 |
+
|
860 |
+
# TODO do not record tag freq when no tag
|
861 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
862 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
863 |
+
|
864 |
+
# check existence of all npz files
|
865 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
866 |
+
if use_npz_latents:
|
867 |
+
npz_any = False
|
868 |
+
npz_all = True
|
869 |
+
for image_info in self.image_data.values():
|
870 |
+
has_npz = image_info.latents_npz is not None
|
871 |
+
npz_any = npz_any or has_npz
|
872 |
+
|
873 |
+
if self.flip_aug:
|
874 |
+
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
875 |
+
npz_all = npz_all and has_npz
|
876 |
+
|
877 |
+
if npz_any and not npz_all:
|
878 |
+
break
|
879 |
+
|
880 |
+
if not npz_any:
|
881 |
+
use_npz_latents = False
|
882 |
+
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
883 |
+
elif not npz_all:
|
884 |
+
use_npz_latents = False
|
885 |
+
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
+
if self.flip_aug:
|
887 |
+
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
+
# else:
|
889 |
+
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
890 |
+
|
891 |
+
# check min/max bucket size
|
892 |
+
sizes = set()
|
893 |
+
resos = set()
|
894 |
+
for image_info in self.image_data.values():
|
895 |
+
if image_info.image_size is None:
|
896 |
+
sizes = None # not calculated
|
897 |
+
break
|
898 |
+
sizes.add(image_info.image_size[0])
|
899 |
+
sizes.add(image_info.image_size[1])
|
900 |
+
resos.add(tuple(image_info.image_size))
|
901 |
+
|
902 |
+
if sizes is None:
|
903 |
+
if use_npz_latents:
|
904 |
+
use_npz_latents = False
|
905 |
+
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
|
906 |
+
|
907 |
+
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
908 |
+
|
909 |
+
self.enable_bucket = enable_bucket
|
910 |
+
if self.enable_bucket:
|
911 |
+
self.min_bucket_reso = min_bucket_reso
|
912 |
+
self.max_bucket_reso = max_bucket_reso
|
913 |
+
self.bucket_reso_steps = bucket_reso_steps
|
914 |
+
self.bucket_no_upscale = bucket_no_upscale
|
915 |
+
else:
|
916 |
+
if not enable_bucket:
|
917 |
+
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
918 |
+
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
919 |
+
self.enable_bucket = True
|
920 |
+
|
921 |
+
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
922 |
+
|
923 |
+
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
924 |
+
self.bucket_manager = BucketManager(False, None, None, None, None)
|
925 |
+
self.bucket_manager.set_predefined_resos(resos)
|
926 |
+
|
927 |
+
# npz情報をきれいにしておく
|
928 |
+
if not use_npz_latents:
|
929 |
+
for image_info in self.image_data.values():
|
930 |
+
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
+
|
932 |
+
def image_key_to_npz_file(self, image_key):
|
933 |
+
base_name = os.path.splitext(image_key)[0]
|
934 |
+
npz_file_norm = base_name + '.npz'
|
935 |
+
|
936 |
+
if os.path.exists(npz_file_norm):
|
937 |
+
# image_key is full path
|
938 |
+
npz_file_flip = base_name + '_flip.npz'
|
939 |
+
if not os.path.exists(npz_file_flip):
|
940 |
+
npz_file_flip = None
|
941 |
+
return npz_file_norm, npz_file_flip
|
942 |
+
|
943 |
+
# image_key is relative path
|
944 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
945 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
946 |
+
|
947 |
+
if not os.path.exists(npz_file_norm):
|
948 |
+
npz_file_norm = None
|
949 |
+
npz_file_flip = None
|
950 |
+
elif not os.path.exists(npz_file_flip):
|
951 |
+
npz_file_flip = None
|
952 |
+
|
953 |
+
return npz_file_norm, npz_file_flip
|
954 |
+
|
955 |
+
|
956 |
+
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
+
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
+
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
+
|
960 |
+
train_dataset.set_current_epoch(1)
|
961 |
+
k = 0
|
962 |
+
for i, example in enumerate(train_dataset):
|
963 |
+
if example['latents'] is not None:
|
964 |
+
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
+
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
966 |
+
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
967 |
+
if show_input_ids:
|
968 |
+
print(f"input ids: {iid}")
|
969 |
+
if example['images'] is not None:
|
970 |
+
im = example['images'][j]
|
971 |
+
print(f"image size: {im.size()}")
|
972 |
+
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
973 |
+
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
974 |
+
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
975 |
+
if os.name == 'nt': # only windows
|
976 |
+
cv2.imshow("img", im)
|
977 |
+
k = cv2.waitKey()
|
978 |
+
cv2.destroyAllWindows()
|
979 |
+
if k == 27:
|
980 |
+
break
|
981 |
+
if k == 27 or (example['images'] is None and i >= 8):
|
982 |
+
break
|
983 |
+
|
984 |
+
|
985 |
+
def glob_images(directory, base="*"):
|
986 |
+
img_paths = []
|
987 |
+
for ext in IMAGE_EXTENSIONS:
|
988 |
+
if base == '*':
|
989 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
990 |
+
else:
|
991 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
992 |
+
# img_paths = list(set(img_paths)) # 重複を排除
|
993 |
+
# img_paths.sort()
|
994 |
+
return img_paths
|
995 |
+
|
996 |
+
|
997 |
+
def glob_images_pathlib(dir_path, recursive):
|
998 |
+
image_paths = []
|
999 |
+
if recursive:
|
1000 |
+
for ext in IMAGE_EXTENSIONS:
|
1001 |
+
image_paths += list(dir_path.rglob('*' + ext))
|
1002 |
+
else:
|
1003 |
+
for ext in IMAGE_EXTENSIONS:
|
1004 |
+
image_paths += list(dir_path.glob('*' + ext))
|
1005 |
+
# image_paths = list(set(image_paths)) # 重複を排除
|
1006 |
+
# image_paths.sort()
|
1007 |
+
return image_paths
|
1008 |
+
|
1009 |
+
# endregion
|
1010 |
+
|
1011 |
+
|
1012 |
+
# region モジュール入れ替え部
|
1013 |
+
"""
|
1014 |
+
高速化のためのモジュール入れ替え
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
# FlashAttentionを使うCrossAttention
|
1018 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
1019 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
1020 |
+
|
1021 |
+
# constants
|
1022 |
+
|
1023 |
+
EPSILON = 1e-6
|
1024 |
+
|
1025 |
+
# helper functions
|
1026 |
+
|
1027 |
+
|
1028 |
+
def exists(val):
|
1029 |
+
return val is not None
|
1030 |
+
|
1031 |
+
|
1032 |
+
def default(val, d):
|
1033 |
+
return val if exists(val) else d
|
1034 |
+
|
1035 |
+
|
1036 |
+
def model_hash(filename):
|
1037 |
+
"""Old model hash used by stable-diffusion-webui"""
|
1038 |
+
try:
|
1039 |
+
with open(filename, "rb") as file:
|
1040 |
+
m = hashlib.sha256()
|
1041 |
+
|
1042 |
+
file.seek(0x100000)
|
1043 |
+
m.update(file.read(0x10000))
|
1044 |
+
return m.hexdigest()[0:8]
|
1045 |
+
except FileNotFoundError:
|
1046 |
+
return 'NOFILE'
|
1047 |
+
|
1048 |
+
|
1049 |
+
def calculate_sha256(filename):
|
1050 |
+
"""New model hash used by stable-diffusion-webui"""
|
1051 |
+
hash_sha256 = hashlib.sha256()
|
1052 |
+
blksize = 1024 * 1024
|
1053 |
+
|
1054 |
+
with open(filename, "rb") as f:
|
1055 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
1056 |
+
hash_sha256.update(chunk)
|
1057 |
+
|
1058 |
+
return hash_sha256.hexdigest()
|
1059 |
+
|
1060 |
+
|
1061 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
1062 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
1063 |
+
save time on indexing the model later."""
|
1064 |
+
|
1065 |
+
# Because writing user metadata to the file can change the result of
|
1066 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
1067 |
+
# calculating the hash, as they are meant to be immutable
|
1068 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
1069 |
+
|
1070 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
1071 |
+
b = BytesIO(bytes)
|
1072 |
+
|
1073 |
+
model_hash = addnet_hash_safetensors(b)
|
1074 |
+
legacy_hash = addnet_hash_legacy(b)
|
1075 |
+
return model_hash, legacy_hash
|
1076 |
+
|
1077 |
+
|
1078 |
+
def addnet_hash_legacy(b):
|
1079 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1080 |
+
m = hashlib.sha256()
|
1081 |
+
|
1082 |
+
b.seek(0x100000)
|
1083 |
+
m.update(b.read(0x10000))
|
1084 |
+
return m.hexdigest()[0:8]
|
1085 |
+
|
1086 |
+
|
1087 |
+
def addnet_hash_safetensors(b):
|
1088 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1089 |
+
hash_sha256 = hashlib.sha256()
|
1090 |
+
blksize = 1024 * 1024
|
1091 |
+
|
1092 |
+
b.seek(0)
|
1093 |
+
header = b.read(8)
|
1094 |
+
n = int.from_bytes(header, "little")
|
1095 |
+
|
1096 |
+
offset = n + 8
|
1097 |
+
b.seek(offset)
|
1098 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
1099 |
+
hash_sha256.update(chunk)
|
1100 |
+
|
1101 |
+
return hash_sha256.hexdigest()
|
1102 |
+
|
1103 |
+
|
1104 |
+
def get_git_revision_hash() -> str:
|
1105 |
+
try:
|
1106 |
+
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
|
1107 |
+
except:
|
1108 |
+
return "(unknown)"
|
1109 |
+
|
1110 |
+
|
1111 |
+
# flash attention forwards and backwards
|
1112 |
+
|
1113 |
+
# https://arxiv.org/abs/2205.14135
|
1114 |
+
|
1115 |
+
|
1116 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
1117 |
+
@ staticmethod
|
1118 |
+
@ torch.no_grad()
|
1119 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
1120 |
+
""" Algorithm 2 in the paper """
|
1121 |
+
|
1122 |
+
device = q.device
|
1123 |
+
dtype = q.dtype
|
1124 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1125 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1126 |
+
|
1127 |
+
o = torch.zeros_like(q)
|
1128 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
1129 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
1130 |
+
|
1131 |
+
scale = (q.shape[-1] ** -0.5)
|
1132 |
+
|
1133 |
+
if not exists(mask):
|
1134 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
1135 |
+
else:
|
1136 |
+
mask = rearrange(mask, 'b n -> b 1 1 n')
|
1137 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
1138 |
+
|
1139 |
+
row_splits = zip(
|
1140 |
+
q.split(q_bucket_size, dim=-2),
|
1141 |
+
o.split(q_bucket_size, dim=-2),
|
1142 |
+
mask,
|
1143 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
1144 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
1148 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1149 |
+
|
1150 |
+
col_splits = zip(
|
1151 |
+
k.split(k_bucket_size, dim=-2),
|
1152 |
+
v.split(k_bucket_size, dim=-2),
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
1156 |
+
k_start_index = k_ind * k_bucket_size
|
1157 |
+
|
1158 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1159 |
+
|
1160 |
+
if exists(row_mask):
|
1161 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
1162 |
+
|
1163 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1164 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1165 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1166 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1167 |
+
|
1168 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
1169 |
+
attn_weights -= block_row_maxes
|
1170 |
+
exp_weights = torch.exp(attn_weights)
|
1171 |
+
|
1172 |
+
if exists(row_mask):
|
1173 |
+
exp_weights.masked_fill_(~row_mask, 0.)
|
1174 |
+
|
1175 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
1176 |
+
|
1177 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
1178 |
+
|
1179 |
+
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
|
1180 |
+
|
1181 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
1182 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
1183 |
+
|
1184 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
1185 |
+
|
1186 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
1187 |
+
|
1188 |
+
row_maxes.copy_(new_row_maxes)
|
1189 |
+
row_sums.copy_(new_row_sums)
|
1190 |
+
|
1191 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
1192 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
1193 |
+
|
1194 |
+
return o
|
1195 |
+
|
1196 |
+
@ staticmethod
|
1197 |
+
@ torch.no_grad()
|
1198 |
+
def backward(ctx, do):
|
1199 |
+
""" Algorithm 4 in the paper """
|
1200 |
+
|
1201 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
1202 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
1203 |
+
|
1204 |
+
device = q.device
|
1205 |
+
|
1206 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1207 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1208 |
+
|
1209 |
+
dq = torch.zeros_like(q)
|
1210 |
+
dk = torch.zeros_like(k)
|
1211 |
+
dv = torch.zeros_like(v)
|
1212 |
+
|
1213 |
+
row_splits = zip(
|
1214 |
+
q.split(q_bucket_size, dim=-2),
|
1215 |
+
o.split(q_bucket_size, dim=-2),
|
1216 |
+
do.split(q_bucket_size, dim=-2),
|
1217 |
+
mask,
|
1218 |
+
l.split(q_bucket_size, dim=-2),
|
1219 |
+
m.split(q_bucket_size, dim=-2),
|
1220 |
+
dq.split(q_bucket_size, dim=-2)
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
1224 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1225 |
+
|
1226 |
+
col_splits = zip(
|
1227 |
+
k.split(k_bucket_size, dim=-2),
|
1228 |
+
v.split(k_bucket_size, dim=-2),
|
1229 |
+
dk.split(k_bucket_size, dim=-2),
|
1230 |
+
dv.split(k_bucket_size, dim=-2),
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
1234 |
+
k_start_index = k_ind * k_bucket_size
|
1235 |
+
|
1236 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1237 |
+
|
1238 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1239 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1240 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1241 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1242 |
+
|
1243 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
1244 |
+
|
1245 |
+
if exists(row_mask):
|
1246 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.)
|
1247 |
+
|
1248 |
+
p = exp_attn_weights / lc
|
1249 |
+
|
1250 |
+
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
|
1251 |
+
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
|
1252 |
+
|
1253 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
1254 |
+
ds = p * scale * (dp - D)
|
1255 |
+
|
1256 |
+
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
|
1257 |
+
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
|
1258 |
+
|
1259 |
+
dqc.add_(dq_chunk)
|
1260 |
+
dkc.add_(dk_chunk)
|
1261 |
+
dvc.add_(dv_chunk)
|
1262 |
+
|
1263 |
+
return dq, dk, dv, None, None, None, None
|
1264 |
+
|
1265 |
+
|
1266 |
+
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
1267 |
+
if mem_eff_attn:
|
1268 |
+
replace_unet_cross_attn_to_memory_efficient()
|
1269 |
+
elif xformers:
|
1270 |
+
replace_unet_cross_attn_to_xformers()
|
1271 |
+
|
1272 |
+
|
1273 |
+
def replace_unet_cross_attn_to_memory_efficient():
|
1274 |
+
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
1275 |
+
flash_func = FlashAttentionFunction
|
1276 |
+
|
1277 |
+
def forward_flash_attn(self, x, context=None, mask=None):
|
1278 |
+
q_bucket_size = 512
|
1279 |
+
k_bucket_size = 1024
|
1280 |
+
|
1281 |
+
h = self.heads
|
1282 |
+
q = self.to_q(x)
|
1283 |
+
|
1284 |
+
context = context if context is not None else x
|
1285 |
+
context = context.to(x.dtype)
|
1286 |
+
|
1287 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1288 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1289 |
+
context_k = context_k.to(x.dtype)
|
1290 |
+
context_v = context_v.to(x.dtype)
|
1291 |
+
else:
|
1292 |
+
context_k = context
|
1293 |
+
context_v = context
|
1294 |
+
|
1295 |
+
k = self.to_k(context_k)
|
1296 |
+
v = self.to_v(context_v)
|
1297 |
+
del context, x
|
1298 |
+
|
1299 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
1300 |
+
|
1301 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
1302 |
+
|
1303 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
1304 |
+
|
1305 |
+
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
1306 |
+
out = self.to_out[0](out)
|
1307 |
+
out = self.to_out[1](out)
|
1308 |
+
return out
|
1309 |
+
|
1310 |
+
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
1311 |
+
|
1312 |
+
|
1313 |
+
def replace_unet_cross_attn_to_xformers():
|
1314 |
+
print("Replace CrossAttention.forward to use xformers")
|
1315 |
+
try:
|
1316 |
+
import xformers.ops
|
1317 |
+
except ImportError:
|
1318 |
+
raise ImportError("No xformers / xformersがインストールされていないようです")
|
1319 |
+
|
1320 |
+
def forward_xformers(self, x, context=None, mask=None):
|
1321 |
+
h = self.heads
|
1322 |
+
q_in = self.to_q(x)
|
1323 |
+
|
1324 |
+
context = default(context, x)
|
1325 |
+
context = context.to(x.dtype)
|
1326 |
+
|
1327 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1328 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1329 |
+
context_k = context_k.to(x.dtype)
|
1330 |
+
context_v = context_v.to(x.dtype)
|
1331 |
+
else:
|
1332 |
+
context_k = context
|
1333 |
+
context_v = context
|
1334 |
+
|
1335 |
+
k_in = self.to_k(context_k)
|
1336 |
+
v_in = self.to_v(context_v)
|
1337 |
+
|
1338 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
1339 |
+
del q_in, k_in, v_in
|
1340 |
+
|
1341 |
+
q = q.contiguous()
|
1342 |
+
k = k.contiguous()
|
1343 |
+
v = v.contiguous()
|
1344 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
1345 |
+
|
1346 |
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
1347 |
+
|
1348 |
+
# diffusers 0.7.0~
|
1349 |
+
out = self.to_out[0](out)
|
1350 |
+
out = self.to_out[1](out)
|
1351 |
+
return out
|
1352 |
+
|
1353 |
+
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
1354 |
+
# endregion
|
1355 |
+
|
1356 |
+
|
1357 |
+
# region arguments
|
1358 |
+
|
1359 |
+
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
1360 |
+
# for pretrained models
|
1361 |
+
parser.add_argument("--v2", action='store_true',
|
1362 |
+
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
1363 |
+
parser.add_argument("--v_parameterization", action='store_true',
|
1364 |
+
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
+
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1367 |
+
|
1368 |
+
|
1369 |
+
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
1370 |
+
parser.add_argument("--output_dir", type=str, default=None,
|
1371 |
+
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
1372 |
+
parser.add_argument("--output_name", type=str, default=None,
|
1373 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
1374 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
1375 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
1376 |
+
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
1377 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
1378 |
+
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
1379 |
+
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
1380 |
+
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
1381 |
+
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
1382 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
1383 |
+
parser.add_argument("--save_state", action="store_true",
|
1384 |
+
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
1385 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1386 |
+
|
1387 |
+
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
+
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
+
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
+
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
+
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
+
parser.add_argument("--xformers", action="store_true",
|
1397 |
+
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
1398 |
+
parser.add_argument("--vae", type=str, default=None,
|
1399 |
+
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
+
|
1401 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
+
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
1405 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
1406 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
1407 |
+
parser.add_argument("--persistent_data_loader_workers", action="store_true",
|
1408 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
|
1409 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1410 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
1411 |
+
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
1412 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
1413 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
|
1414 |
+
parser.add_argument("--mixed_precision", type=str, default="no",
|
1415 |
+
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
1416 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1417 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
1418 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
1419 |
+
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
+
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
+
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
+
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
+
parser.add_argument("--lowram", action="store_true",
|
1429 |
+
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
+
|
1431 |
+
if support_dreambooth:
|
1432 |
+
# DreamBooth training
|
1433 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
1434 |
+
help="loss weight for regularization images / 正則化画像のlossの重み")
|
1435 |
+
|
1436 |
+
|
1437 |
+
def verify_training_args(args: argparse.Namespace):
|
1438 |
+
if args.v_parameterization and not args.v2:
|
1439 |
+
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
1440 |
+
if args.v2 and args.clip_skip is not None:
|
1441 |
+
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
1442 |
+
|
1443 |
+
|
1444 |
+
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
1445 |
+
# dataset common
|
1446 |
+
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
1447 |
+
parser.add_argument("--shuffle_caption", action="store_true",
|
1448 |
+
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
1449 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
+
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1453 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1454 |
+
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
+
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
+
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
1457 |
+
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
|
1458 |
+
parser.add_argument("--random_crop", action="store_true",
|
1459 |
+
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
|
1460 |
+
parser.add_argument("--debug_dataset", action="store_true",
|
1461 |
+
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
1462 |
+
parser.add_argument("--resolution", type=str, default=None,
|
1463 |
+
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
|
1464 |
+
parser.add_argument("--cache_latents", action="store_true",
|
1465 |
+
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
1466 |
+
parser.add_argument("--enable_bucket", action="store_true",
|
1467 |
+
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
1468 |
+
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
1469 |
+
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
1470 |
+
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
1471 |
+
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
1472 |
+
parser.add_argument("--bucket_no_upscale", action="store_true",
|
1473 |
+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
1474 |
+
|
1475 |
+
if support_caption_dropout:
|
1476 |
+
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
+
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
+
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1481 |
+
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
+
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
+
|
1485 |
+
if support_dreambooth:
|
1486 |
+
# DreamBooth dataset
|
1487 |
+
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
1488 |
+
|
1489 |
+
if support_caption:
|
1490 |
+
# caption dataset
|
1491 |
+
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
|
1492 |
+
parser.add_argument("--dataset_repeats", type=int, default=1,
|
1493 |
+
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
|
1494 |
+
|
1495 |
+
|
1496 |
+
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
1497 |
+
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
|
1498 |
+
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
|
1499 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
1500 |
+
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
1501 |
+
|
1502 |
+
# endregion
|
1503 |
+
|
1504 |
+
# region utils
|
1505 |
+
|
1506 |
+
|
1507 |
+
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
+
# backward compatibility
|
1509 |
+
if args.caption_extention is not None:
|
1510 |
+
args.caption_extension = args.caption_extention
|
1511 |
+
args.caption_extention = None
|
1512 |
+
|
1513 |
+
if args.cache_latents:
|
1514 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
+
|
1517 |
+
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
+
if args.resolution is not None:
|
1519 |
+
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
1520 |
+
if len(args.resolution) == 1:
|
1521 |
+
args.resolution = (args.resolution[0], args.resolution[0])
|
1522 |
+
assert len(args.resolution) == 2, \
|
1523 |
+
f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
1524 |
+
|
1525 |
+
if args.face_crop_aug_range is not None:
|
1526 |
+
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
1527 |
+
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
|
1528 |
+
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
|
1529 |
+
else:
|
1530 |
+
args.face_crop_aug_range = None
|
1531 |
+
|
1532 |
+
if support_metadata:
|
1533 |
+
if args.in_json is not None and (args.color_aug or args.random_crop):
|
1534 |
+
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
|
1535 |
+
|
1536 |
+
|
1537 |
+
def load_tokenizer(args: argparse.Namespace):
|
1538 |
+
print("prepare tokenizer")
|
1539 |
+
if args.v2:
|
1540 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1541 |
+
else:
|
1542 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1543 |
+
if args.max_token_length is not None:
|
1544 |
+
print(f"update token length: {args.max_token_length}")
|
1545 |
+
return tokenizer
|
1546 |
+
|
1547 |
+
|
1548 |
+
def prepare_accelerator(args: argparse.Namespace):
|
1549 |
+
if args.logging_dir is None:
|
1550 |
+
log_with = None
|
1551 |
+
logging_dir = None
|
1552 |
+
else:
|
1553 |
+
log_with = "tensorboard"
|
1554 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
1555 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
1556 |
+
|
1557 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
|
1558 |
+
log_with=log_with, logging_dir=logging_dir)
|
1559 |
+
|
1560 |
+
# accelerateの互換性問題を解決する
|
1561 |
+
accelerator_0_15 = True
|
1562 |
+
try:
|
1563 |
+
accelerator.unwrap_model("dummy", True)
|
1564 |
+
print("Using accelerator 0.15.0 or above.")
|
1565 |
+
except TypeError:
|
1566 |
+
accelerator_0_15 = False
|
1567 |
+
|
1568 |
+
def unwrap_model(model):
|
1569 |
+
if accelerator_0_15:
|
1570 |
+
return accelerator.unwrap_model(model, True)
|
1571 |
+
return accelerator.unwrap_model(model)
|
1572 |
+
|
1573 |
+
return accelerator, unwrap_model
|
1574 |
+
|
1575 |
+
|
1576 |
+
def prepare_dtype(args: argparse.Namespace):
|
1577 |
+
weight_dtype = torch.float32
|
1578 |
+
if args.mixed_precision == "fp16":
|
1579 |
+
weight_dtype = torch.float16
|
1580 |
+
elif args.mixed_precision == "bf16":
|
1581 |
+
weight_dtype = torch.bfloat16
|
1582 |
+
|
1583 |
+
save_dtype = None
|
1584 |
+
if args.save_precision == "fp16":
|
1585 |
+
save_dtype = torch.float16
|
1586 |
+
elif args.save_precision == "bf16":
|
1587 |
+
save_dtype = torch.bfloat16
|
1588 |
+
elif args.save_precision == "float":
|
1589 |
+
save_dtype = torch.float32
|
1590 |
+
|
1591 |
+
return weight_dtype, save_dtype
|
1592 |
+
|
1593 |
+
|
1594 |
+
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
1596 |
+
if load_stable_diffusion_format:
|
1597 |
+
print("load StableDiffusion checkpoint")
|
1598 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1599 |
+
else:
|
1600 |
+
print("load Diffusers pretrained models")
|
1601 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
1602 |
+
text_encoder = pipe.text_encoder
|
1603 |
+
vae = pipe.vae
|
1604 |
+
unet = pipe.unet
|
1605 |
+
del pipe
|
1606 |
+
|
1607 |
+
# VAEを読み込む
|
1608 |
+
if args.vae is not None:
|
1609 |
+
vae = model_util.load_vae(args.vae, weight_dtype)
|
1610 |
+
print("additional VAE loaded")
|
1611 |
+
|
1612 |
+
return text_encoder, vae, unet, load_stable_diffusion_format
|
1613 |
+
|
1614 |
+
|
1615 |
+
def patch_accelerator_for_fp16_training(accelerator):
|
1616 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
1617 |
+
|
1618 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
1619 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
1620 |
+
|
1621 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
1622 |
+
|
1623 |
+
|
1624 |
+
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
|
1625 |
+
# with no_token_padding, the length is not max length, return result immediately
|
1626 |
+
if input_ids.size()[-1] != tokenizer.model_max_length:
|
1627 |
+
return text_encoder(input_ids)[0]
|
1628 |
+
|
1629 |
+
b_size = input_ids.size()[0]
|
1630 |
+
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
1631 |
+
|
1632 |
+
if args.clip_skip is None:
|
1633 |
+
encoder_hidden_states = text_encoder(input_ids)[0]
|
1634 |
+
else:
|
1635 |
+
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
1636 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
1637 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
1638 |
+
|
1639 |
+
# bs*3, 77, 768 or 1024
|
1640 |
+
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
1641 |
+
|
1642 |
+
if args.max_token_length is not None:
|
1643 |
+
if args.v2:
|
1644 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
1645 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1646 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1647 |
+
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
1648 |
+
if i > 0:
|
1649 |
+
for j in range(len(chunk)):
|
1650 |
+
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
1651 |
+
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
1652 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
1653 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
1654 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1655 |
+
else:
|
1656 |
+
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
1657 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1658 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1659 |
+
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
1660 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
1661 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1662 |
+
|
1663 |
+
if weight_dtype is not None:
|
1664 |
+
# this is required for additional network training
|
1665 |
+
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
1666 |
+
|
1667 |
+
return encoder_hidden_states
|
1668 |
+
|
1669 |
+
|
1670 |
+
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
1671 |
+
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
1672 |
+
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
1673 |
+
return model_name, ckpt_name
|
1674 |
+
|
1675 |
+
|
1676 |
+
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
1677 |
+
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
1678 |
+
if saving:
|
1679 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1680 |
+
save_func()
|
1681 |
+
|
1682 |
+
if args.save_last_n_epochs is not None:
|
1683 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
1684 |
+
remove_old_func(remove_epoch_no)
|
1685 |
+
return saving
|
1686 |
+
|
1687 |
+
|
1688 |
+
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
1689 |
+
epoch_no = epoch + 1
|
1690 |
+
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
1691 |
+
|
1692 |
+
if save_stable_diffusion_format:
|
1693 |
+
def save_sd():
|
1694 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1695 |
+
print(f"saving checkpoint: {ckpt_file}")
|
1696 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1697 |
+
src_path, epoch_no, global_step, save_dtype, vae)
|
1698 |
+
|
1699 |
+
def remove_sd(old_epoch_no):
|
1700 |
+
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
1701 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1702 |
+
if os.path.exists(old_ckpt_file):
|
1703 |
+
print(f"removing old checkpoint: {old_ckpt_file}")
|
1704 |
+
os.remove(old_ckpt_file)
|
1705 |
+
|
1706 |
+
save_func = save_sd
|
1707 |
+
remove_old_func = remove_sd
|
1708 |
+
else:
|
1709 |
+
def save_du():
|
1710 |
+
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
1711 |
+
print(f"saving model: {out_dir}")
|
1712 |
+
os.makedirs(out_dir, exist_ok=True)
|
1713 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1714 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1715 |
+
|
1716 |
+
def remove_du(old_epoch_no):
|
1717 |
+
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
1718 |
+
if os.path.exists(out_dir_old):
|
1719 |
+
print(f"removing old model: {out_dir_old}")
|
1720 |
+
shutil.rmtree(out_dir_old)
|
1721 |
+
|
1722 |
+
save_func = save_du
|
1723 |
+
remove_old_func = remove_du
|
1724 |
+
|
1725 |
+
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
1726 |
+
if saving and args.save_state:
|
1727 |
+
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
1728 |
+
|
1729 |
+
|
1730 |
+
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
1731 |
+
print("saving state.")
|
1732 |
+
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
1733 |
+
|
1734 |
+
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
1735 |
+
if last_n_epochs is not None:
|
1736 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
1737 |
+
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
1738 |
+
if os.path.exists(state_dir_old):
|
1739 |
+
print(f"removing old state: {state_dir_old}")
|
1740 |
+
shutil.rmtree(state_dir_old)
|
1741 |
+
|
1742 |
+
|
1743 |
+
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
|
1744 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1745 |
+
|
1746 |
+
if save_stable_diffusion_format:
|
1747 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1748 |
+
|
1749 |
+
ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
|
1750 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1751 |
+
|
1752 |
+
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
1753 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1754 |
+
src_path, epoch, global_step, save_dtype, vae)
|
1755 |
+
else:
|
1756 |
+
out_dir = os.path.join(args.output_dir, model_name)
|
1757 |
+
os.makedirs(out_dir, exist_ok=True)
|
1758 |
+
|
1759 |
+
print(f"save trained model as Diffusers to {out_dir}")
|
1760 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1761 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1762 |
+
|
1763 |
+
|
1764 |
+
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
1765 |
+
print("saving last state.")
|
1766 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1767 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
+
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
+
|
1770 |
+
# endregion
|
1771 |
+
|
1772 |
+
# region 前処理用
|
1773 |
+
|
1774 |
+
|
1775 |
+
class ImageLoadingDataset(torch.utils.data.Dataset):
|
1776 |
+
def __init__(self, image_paths):
|
1777 |
+
self.images = image_paths
|
1778 |
+
|
1779 |
+
def __len__(self):
|
1780 |
+
return len(self.images)
|
1781 |
+
|
1782 |
+
def __getitem__(self, idx):
|
1783 |
+
img_path = self.images[idx]
|
1784 |
+
|
1785 |
+
try:
|
1786 |
+
image = Image.open(img_path).convert("RGB")
|
1787 |
+
# convert to tensor temporarily so dataloader will accept it
|
1788 |
+
tensor_pil = transforms.functional.pil_to_tensor(image)
|
1789 |
+
except Exception as e:
|
1790 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
1791 |
+
return None
|
1792 |
+
|
1793 |
+
return (tensor_pil, img_path)
|
1794 |
+
|
1795 |
+
|
1796 |
+
# endregion
|
locon/__init__.py
ADDED
File without changes
|
locon/kohya_model_utils.py
ADDED
@@ -0,0 +1,1184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
3 |
+
'''
|
4 |
+
# v1: split from train_db_fixed.py.
|
5 |
+
# v2: support safetensors
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
11 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
12 |
+
from safetensors.torch import load_file, save_file
|
13 |
+
|
14 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
15 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
16 |
+
BETA_START = 0.00085
|
17 |
+
BETA_END = 0.0120
|
18 |
+
|
19 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
20 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
21 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
22 |
+
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
23 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
24 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
25 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
26 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
27 |
+
UNET_PARAMS_NUM_HEADS = 8
|
28 |
+
|
29 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
30 |
+
VAE_PARAMS_RESOLUTION = 256
|
31 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
32 |
+
VAE_PARAMS_OUT_CH = 3
|
33 |
+
VAE_PARAMS_CH = 128
|
34 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
35 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
36 |
+
|
37 |
+
# V2
|
38 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
39 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
40 |
+
|
41 |
+
# Diffusersの設定を読み込むための参照モデル
|
42 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
43 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
44 |
+
|
45 |
+
|
46 |
+
# region StableDiffusion->Diffusersの変換コード
|
47 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
48 |
+
|
49 |
+
|
50 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
51 |
+
"""
|
52 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
53 |
+
"""
|
54 |
+
if n_shave_prefix_segments >= 0:
|
55 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
56 |
+
else:
|
57 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
58 |
+
|
59 |
+
|
60 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
61 |
+
"""
|
62 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
63 |
+
"""
|
64 |
+
mapping = []
|
65 |
+
for old_item in old_list:
|
66 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
67 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
68 |
+
|
69 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
70 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
71 |
+
|
72 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
73 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
74 |
+
|
75 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
76 |
+
|
77 |
+
mapping.append({"old": old_item, "new": new_item})
|
78 |
+
|
79 |
+
return mapping
|
80 |
+
|
81 |
+
|
82 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
83 |
+
"""
|
84 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
85 |
+
"""
|
86 |
+
mapping = []
|
87 |
+
for old_item in old_list:
|
88 |
+
new_item = old_item
|
89 |
+
|
90 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
91 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
92 |
+
|
93 |
+
mapping.append({"old": old_item, "new": new_item})
|
94 |
+
|
95 |
+
return mapping
|
96 |
+
|
97 |
+
|
98 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
99 |
+
"""
|
100 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
101 |
+
"""
|
102 |
+
mapping = []
|
103 |
+
for old_item in old_list:
|
104 |
+
new_item = old_item
|
105 |
+
|
106 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
107 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
108 |
+
|
109 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
110 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
111 |
+
|
112 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
113 |
+
|
114 |
+
mapping.append({"old": old_item, "new": new_item})
|
115 |
+
|
116 |
+
return mapping
|
117 |
+
|
118 |
+
|
119 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
120 |
+
"""
|
121 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
122 |
+
"""
|
123 |
+
mapping = []
|
124 |
+
for old_item in old_list:
|
125 |
+
new_item = old_item
|
126 |
+
|
127 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
128 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
131 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
134 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
137 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
140 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
141 |
+
|
142 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
143 |
+
|
144 |
+
mapping.append({"old": old_item, "new": new_item})
|
145 |
+
|
146 |
+
return mapping
|
147 |
+
|
148 |
+
|
149 |
+
def assign_to_checkpoint(
|
150 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
154 |
+
to them. It splits attention layers, and takes into account additional replacements
|
155 |
+
that may arise.
|
156 |
+
|
157 |
+
Assigns the weights to the new checkpoint.
|
158 |
+
"""
|
159 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
160 |
+
|
161 |
+
# Splits the attention layers into three variables.
|
162 |
+
if attention_paths_to_split is not None:
|
163 |
+
for path, path_map in attention_paths_to_split.items():
|
164 |
+
old_tensor = old_checkpoint[path]
|
165 |
+
channels = old_tensor.shape[0] // 3
|
166 |
+
|
167 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
168 |
+
|
169 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
170 |
+
|
171 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
172 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
173 |
+
|
174 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
175 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
176 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
177 |
+
|
178 |
+
for path in paths:
|
179 |
+
new_path = path["new"]
|
180 |
+
|
181 |
+
# These have already been assigned
|
182 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
183 |
+
continue
|
184 |
+
|
185 |
+
# Global renaming happens here
|
186 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
187 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
188 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
189 |
+
|
190 |
+
if additional_replacements is not None:
|
191 |
+
for replacement in additional_replacements:
|
192 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
193 |
+
|
194 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
195 |
+
if "proj_attn.weight" in new_path:
|
196 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
197 |
+
else:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
199 |
+
|
200 |
+
|
201 |
+
def conv_attn_to_linear(checkpoint):
|
202 |
+
keys = list(checkpoint.keys())
|
203 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
204 |
+
for key in keys:
|
205 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
208 |
+
elif "proj_attn.weight" in key:
|
209 |
+
if checkpoint[key].ndim > 2:
|
210 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
211 |
+
|
212 |
+
|
213 |
+
def linear_transformer_to_conv(checkpoint):
|
214 |
+
keys = list(checkpoint.keys())
|
215 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
216 |
+
for key in keys:
|
217 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
218 |
+
if checkpoint[key].ndim == 2:
|
219 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
220 |
+
|
221 |
+
|
222 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
223 |
+
"""
|
224 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
225 |
+
"""
|
226 |
+
|
227 |
+
# extract state_dict for UNet
|
228 |
+
unet_state_dict = {}
|
229 |
+
unet_key = "model.diffusion_model."
|
230 |
+
keys = list(checkpoint.keys())
|
231 |
+
for key in keys:
|
232 |
+
if key.startswith(unet_key):
|
233 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
234 |
+
|
235 |
+
new_checkpoint = {}
|
236 |
+
|
237 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
238 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
239 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
240 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
243 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
244 |
+
|
245 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
246 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
247 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
248 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
249 |
+
|
250 |
+
# Retrieves the keys for the input blocks only
|
251 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
252 |
+
input_blocks = {
|
253 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
254 |
+
for layer_id in range(num_input_blocks)
|
255 |
+
}
|
256 |
+
|
257 |
+
# Retrieves the keys for the middle blocks only
|
258 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
259 |
+
middle_blocks = {
|
260 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
261 |
+
for layer_id in range(num_middle_blocks)
|
262 |
+
}
|
263 |
+
|
264 |
+
# Retrieves the keys for the output blocks only
|
265 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
266 |
+
output_blocks = {
|
267 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
268 |
+
for layer_id in range(num_output_blocks)
|
269 |
+
}
|
270 |
+
|
271 |
+
for i in range(1, num_input_blocks):
|
272 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
273 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
274 |
+
|
275 |
+
resnets = [
|
276 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
277 |
+
]
|
278 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
279 |
+
|
280 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.weight"
|
283 |
+
)
|
284 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
285 |
+
f"input_blocks.{i}.0.op.bias"
|
286 |
+
)
|
287 |
+
|
288 |
+
paths = renew_resnet_paths(resnets)
|
289 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
290 |
+
assign_to_checkpoint(
|
291 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
292 |
+
)
|
293 |
+
|
294 |
+
if len(attentions):
|
295 |
+
paths = renew_attention_paths(attentions)
|
296 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
297 |
+
assign_to_checkpoint(
|
298 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
299 |
+
)
|
300 |
+
|
301 |
+
resnet_0 = middle_blocks[0]
|
302 |
+
attentions = middle_blocks[1]
|
303 |
+
resnet_1 = middle_blocks[2]
|
304 |
+
|
305 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
306 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
309 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
310 |
+
|
311 |
+
attentions_paths = renew_attention_paths(attentions)
|
312 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
313 |
+
assign_to_checkpoint(
|
314 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
315 |
+
)
|
316 |
+
|
317 |
+
for i in range(num_output_blocks):
|
318 |
+
block_id = i // (config["layers_per_block"] + 1)
|
319 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
320 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
321 |
+
output_block_list = {}
|
322 |
+
|
323 |
+
for layer in output_block_layers:
|
324 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
325 |
+
if layer_id in output_block_list:
|
326 |
+
output_block_list[layer_id].append(layer_name)
|
327 |
+
else:
|
328 |
+
output_block_list[layer_id] = [layer_name]
|
329 |
+
|
330 |
+
if len(output_block_list) > 1:
|
331 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
332 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
333 |
+
|
334 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
335 |
+
paths = renew_resnet_paths(resnets)
|
336 |
+
|
337 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
338 |
+
assign_to_checkpoint(
|
339 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
340 |
+
)
|
341 |
+
|
342 |
+
# オリジナル:
|
343 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
344 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
345 |
+
|
346 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
347 |
+
for l in output_block_list.values():
|
348 |
+
l.sort()
|
349 |
+
|
350 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
351 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
354 |
+
]
|
355 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
356 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
357 |
+
]
|
358 |
+
|
359 |
+
# Clear attentions as they have been attributed above.
|
360 |
+
if len(attentions) == 2:
|
361 |
+
attentions = []
|
362 |
+
|
363 |
+
if len(attentions):
|
364 |
+
paths = renew_attention_paths(attentions)
|
365 |
+
meta_path = {
|
366 |
+
"old": f"output_blocks.{i}.1",
|
367 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
368 |
+
}
|
369 |
+
assign_to_checkpoint(
|
370 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
374 |
+
for path in resnet_0_paths:
|
375 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
376 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
377 |
+
|
378 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
379 |
+
|
380 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
381 |
+
if v2:
|
382 |
+
linear_transformer_to_conv(new_checkpoint)
|
383 |
+
|
384 |
+
return new_checkpoint
|
385 |
+
|
386 |
+
|
387 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
388 |
+
# extract state dict for VAE
|
389 |
+
vae_state_dict = {}
|
390 |
+
vae_key = "first_stage_model."
|
391 |
+
keys = list(checkpoint.keys())
|
392 |
+
for key in keys:
|
393 |
+
if key.startswith(vae_key):
|
394 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
395 |
+
# if len(vae_state_dict) == 0:
|
396 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
397 |
+
# vae_state_dict = checkpoint
|
398 |
+
|
399 |
+
new_checkpoint = {}
|
400 |
+
|
401 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
402 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
403 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
404 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
405 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
406 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
407 |
+
|
408 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
409 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
410 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
411 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
412 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
413 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
414 |
+
|
415 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
416 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
417 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
418 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
419 |
+
|
420 |
+
# Retrieves the keys for the encoder down blocks only
|
421 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
422 |
+
down_blocks = {
|
423 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
424 |
+
}
|
425 |
+
|
426 |
+
# Retrieves the keys for the decoder up blocks only
|
427 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
428 |
+
up_blocks = {
|
429 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
430 |
+
}
|
431 |
+
|
432 |
+
for i in range(num_down_blocks):
|
433 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
434 |
+
|
435 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
438 |
+
)
|
439 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
440 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
441 |
+
)
|
442 |
+
|
443 |
+
paths = renew_vae_resnet_paths(resnets)
|
444 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
445 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
446 |
+
|
447 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
448 |
+
num_mid_res_blocks = 2
|
449 |
+
for i in range(1, num_mid_res_blocks + 1):
|
450 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
451 |
+
|
452 |
+
paths = renew_vae_resnet_paths(resnets)
|
453 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
454 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
+
|
456 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
457 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
458 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
459 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
460 |
+
conv_attn_to_linear(new_checkpoint)
|
461 |
+
|
462 |
+
for i in range(num_up_blocks):
|
463 |
+
block_id = num_up_blocks - 1 - i
|
464 |
+
resnets = [
|
465 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
466 |
+
]
|
467 |
+
|
468 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
471 |
+
]
|
472 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
473 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
474 |
+
]
|
475 |
+
|
476 |
+
paths = renew_vae_resnet_paths(resnets)
|
477 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
478 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
479 |
+
|
480 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
481 |
+
num_mid_res_blocks = 2
|
482 |
+
for i in range(1, num_mid_res_blocks + 1):
|
483 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
484 |
+
|
485 |
+
paths = renew_vae_resnet_paths(resnets)
|
486 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
487 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
488 |
+
|
489 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
490 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
491 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
492 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
493 |
+
conv_attn_to_linear(new_checkpoint)
|
494 |
+
return new_checkpoint
|
495 |
+
|
496 |
+
|
497 |
+
def create_unet_diffusers_config(v2):
|
498 |
+
"""
|
499 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
500 |
+
"""
|
501 |
+
# unet_params = original_config.model.params.unet_config.params
|
502 |
+
|
503 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
504 |
+
|
505 |
+
down_block_types = []
|
506 |
+
resolution = 1
|
507 |
+
for i in range(len(block_out_channels)):
|
508 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
509 |
+
down_block_types.append(block_type)
|
510 |
+
if i != len(block_out_channels) - 1:
|
511 |
+
resolution *= 2
|
512 |
+
|
513 |
+
up_block_types = []
|
514 |
+
for i in range(len(block_out_channels)):
|
515 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
516 |
+
up_block_types.append(block_type)
|
517 |
+
resolution //= 2
|
518 |
+
|
519 |
+
config = dict(
|
520 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
521 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
522 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
523 |
+
down_block_types=tuple(down_block_types),
|
524 |
+
up_block_types=tuple(up_block_types),
|
525 |
+
block_out_channels=tuple(block_out_channels),
|
526 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
527 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
528 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
529 |
+
)
|
530 |
+
|
531 |
+
return config
|
532 |
+
|
533 |
+
|
534 |
+
def create_vae_diffusers_config():
|
535 |
+
"""
|
536 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
537 |
+
"""
|
538 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
539 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
540 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
541 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
542 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
543 |
+
|
544 |
+
config = dict(
|
545 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
546 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
547 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
548 |
+
down_block_types=tuple(down_block_types),
|
549 |
+
up_block_types=tuple(up_block_types),
|
550 |
+
block_out_channels=tuple(block_out_channels),
|
551 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
552 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
553 |
+
)
|
554 |
+
return config
|
555 |
+
|
556 |
+
|
557 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
558 |
+
keys = list(checkpoint.keys())
|
559 |
+
text_model_dict = {}
|
560 |
+
for key in keys:
|
561 |
+
if key.startswith("cond_stage_model.transformer"):
|
562 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
563 |
+
return text_model_dict
|
564 |
+
|
565 |
+
|
566 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
567 |
+
# 嫌になるくらい違うぞ!
|
568 |
+
def convert_key(key):
|
569 |
+
if not key.startswith("cond_stage_model"):
|
570 |
+
return None
|
571 |
+
|
572 |
+
# common conversion
|
573 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
574 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
575 |
+
|
576 |
+
if "resblocks" in key:
|
577 |
+
# resblocks conversion
|
578 |
+
key = key.replace(".resblocks.", ".layers.")
|
579 |
+
if ".ln_" in key:
|
580 |
+
key = key.replace(".ln_", ".layer_norm")
|
581 |
+
elif ".mlp." in key:
|
582 |
+
key = key.replace(".c_fc.", ".fc1.")
|
583 |
+
key = key.replace(".c_proj.", ".fc2.")
|
584 |
+
elif '.attn.out_proj' in key:
|
585 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
586 |
+
elif '.attn.in_proj' in key:
|
587 |
+
key = None # 特殊なので後で処理する
|
588 |
+
else:
|
589 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
590 |
+
elif '.positional_embedding' in key:
|
591 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
592 |
+
elif '.text_projection' in key:
|
593 |
+
key = None # 使われない???
|
594 |
+
elif '.logit_scale' in key:
|
595 |
+
key = None # 使われない???
|
596 |
+
elif '.token_embedding' in key:
|
597 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
598 |
+
elif '.ln_final' in key:
|
599 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
600 |
+
return key
|
601 |
+
|
602 |
+
keys = list(checkpoint.keys())
|
603 |
+
new_sd = {}
|
604 |
+
for key in keys:
|
605 |
+
# remove resblocks 23
|
606 |
+
if '.resblocks.23.' in key:
|
607 |
+
continue
|
608 |
+
new_key = convert_key(key)
|
609 |
+
if new_key is None:
|
610 |
+
continue
|
611 |
+
new_sd[new_key] = checkpoint[key]
|
612 |
+
|
613 |
+
# attnの変換
|
614 |
+
for key in keys:
|
615 |
+
if '.resblocks.23.' in key:
|
616 |
+
continue
|
617 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
618 |
+
# 三つに分割
|
619 |
+
values = torch.chunk(checkpoint[key], 3)
|
620 |
+
|
621 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
622 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
623 |
+
key_pfx = key_pfx.replace("_weight", "")
|
624 |
+
key_pfx = key_pfx.replace("_bias", "")
|
625 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
626 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
627 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
628 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
629 |
+
|
630 |
+
# rename or add position_ids
|
631 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
632 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
633 |
+
# waifu diffusion v1.4
|
634 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
635 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
636 |
+
else:
|
637 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
638 |
+
|
639 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
640 |
+
return new_sd
|
641 |
+
|
642 |
+
# endregion
|
643 |
+
|
644 |
+
|
645 |
+
# region Diffusers->StableDiffusion の変換コード
|
646 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
647 |
+
|
648 |
+
def conv_transformer_to_linear(checkpoint):
|
649 |
+
keys = list(checkpoint.keys())
|
650 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
651 |
+
for key in keys:
|
652 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
653 |
+
if checkpoint[key].ndim > 2:
|
654 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
655 |
+
|
656 |
+
|
657 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
658 |
+
unet_conversion_map = [
|
659 |
+
# (stable-diffusion, HF Diffusers)
|
660 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
661 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
662 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
663 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
664 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
665 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
666 |
+
("out.0.weight", "conv_norm_out.weight"),
|
667 |
+
("out.0.bias", "conv_norm_out.bias"),
|
668 |
+
("out.2.weight", "conv_out.weight"),
|
669 |
+
("out.2.bias", "conv_out.bias"),
|
670 |
+
]
|
671 |
+
|
672 |
+
unet_conversion_map_resnet = [
|
673 |
+
# (stable-diffusion, HF Diffusers)
|
674 |
+
("in_layers.0", "norm1"),
|
675 |
+
("in_layers.2", "conv1"),
|
676 |
+
("out_layers.0", "norm2"),
|
677 |
+
("out_layers.3", "conv2"),
|
678 |
+
("emb_layers.1", "time_emb_proj"),
|
679 |
+
("skip_connection", "conv_shortcut"),
|
680 |
+
]
|
681 |
+
|
682 |
+
unet_conversion_map_layer = []
|
683 |
+
for i in range(4):
|
684 |
+
# loop over downblocks/upblocks
|
685 |
+
|
686 |
+
for j in range(2):
|
687 |
+
# loop over resnets/attentions for downblocks
|
688 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
689 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
690 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
691 |
+
|
692 |
+
if i < 3:
|
693 |
+
# no attention layers in down_blocks.3
|
694 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
695 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
696 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
697 |
+
|
698 |
+
for j in range(3):
|
699 |
+
# loop over resnets/attentions for upblocks
|
700 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
701 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
702 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
703 |
+
|
704 |
+
if i > 0:
|
705 |
+
# no attention layers in up_blocks.0
|
706 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
707 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
708 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
709 |
+
|
710 |
+
if i < 3:
|
711 |
+
# no downsample in down_blocks.3
|
712 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
713 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
714 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
715 |
+
|
716 |
+
# no upsample in up_blocks.3
|
717 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
718 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
719 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
720 |
+
|
721 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
722 |
+
sd_mid_atn_prefix = "middle_block.1."
|
723 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
724 |
+
|
725 |
+
for j in range(2):
|
726 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
727 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
728 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
729 |
+
|
730 |
+
# buyer beware: this is a *brittle* function,
|
731 |
+
# and correct output requires that all of these pieces interact in
|
732 |
+
# the exact order in which I have arranged them.
|
733 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
734 |
+
for sd_name, hf_name in unet_conversion_map:
|
735 |
+
mapping[hf_name] = sd_name
|
736 |
+
for k, v in mapping.items():
|
737 |
+
if "resnets" in k:
|
738 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
739 |
+
v = v.replace(hf_part, sd_part)
|
740 |
+
mapping[k] = v
|
741 |
+
for k, v in mapping.items():
|
742 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
743 |
+
v = v.replace(hf_part, sd_part)
|
744 |
+
mapping[k] = v
|
745 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
746 |
+
|
747 |
+
if v2:
|
748 |
+
conv_transformer_to_linear(new_state_dict)
|
749 |
+
|
750 |
+
return new_state_dict
|
751 |
+
|
752 |
+
|
753 |
+
# ================#
|
754 |
+
# VAE Conversion #
|
755 |
+
# ================#
|
756 |
+
|
757 |
+
def reshape_weight_for_sd(w):
|
758 |
+
# convert HF linear weights to SD conv2d weights
|
759 |
+
return w.reshape(*w.shape, 1, 1)
|
760 |
+
|
761 |
+
|
762 |
+
def convert_vae_state_dict(vae_state_dict):
|
763 |
+
vae_conversion_map = [
|
764 |
+
# (stable-diffusion, HF Diffusers)
|
765 |
+
("nin_shortcut", "conv_shortcut"),
|
766 |
+
("norm_out", "conv_norm_out"),
|
767 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
768 |
+
]
|
769 |
+
|
770 |
+
for i in range(4):
|
771 |
+
# down_blocks have two resnets
|
772 |
+
for j in range(2):
|
773 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
774 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
775 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
776 |
+
|
777 |
+
if i < 3:
|
778 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
779 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
780 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
781 |
+
|
782 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
783 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
784 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
785 |
+
|
786 |
+
# up_blocks have three resnets
|
787 |
+
# also, up blocks in hf are numbered in reverse from sd
|
788 |
+
for j in range(3):
|
789 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
790 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
791 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
792 |
+
|
793 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
794 |
+
for i in range(2):
|
795 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
796 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
797 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
798 |
+
|
799 |
+
vae_conversion_map_attn = [
|
800 |
+
# (stable-diffusion, HF Diffusers)
|
801 |
+
("norm.", "group_norm."),
|
802 |
+
("q.", "query."),
|
803 |
+
("k.", "key."),
|
804 |
+
("v.", "value."),
|
805 |
+
("proj_out.", "proj_attn."),
|
806 |
+
]
|
807 |
+
|
808 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
809 |
+
for k, v in mapping.items():
|
810 |
+
for sd_part, hf_part in vae_conversion_map:
|
811 |
+
v = v.replace(hf_part, sd_part)
|
812 |
+
mapping[k] = v
|
813 |
+
for k, v in mapping.items():
|
814 |
+
if "attentions" in k:
|
815 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
816 |
+
v = v.replace(hf_part, sd_part)
|
817 |
+
mapping[k] = v
|
818 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
819 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
820 |
+
for k, v in new_state_dict.items():
|
821 |
+
for weight_name in weights_to_convert:
|
822 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
823 |
+
# print(f"Reshaping {k} for SD format")
|
824 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
825 |
+
|
826 |
+
return new_state_dict
|
827 |
+
|
828 |
+
|
829 |
+
# endregion
|
830 |
+
|
831 |
+
# region 自作のモデル読み書きなど
|
832 |
+
|
833 |
+
def is_safetensors(path):
|
834 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
835 |
+
|
836 |
+
|
837 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
838 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
839 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
840 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
841 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
842 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
843 |
+
]
|
844 |
+
|
845 |
+
if is_safetensors(ckpt_path):
|
846 |
+
checkpoint = None
|
847 |
+
state_dict = load_file(ckpt_path, "cpu")
|
848 |
+
else:
|
849 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
850 |
+
if "state_dict" in checkpoint:
|
851 |
+
state_dict = checkpoint["state_dict"]
|
852 |
+
else:
|
853 |
+
state_dict = checkpoint
|
854 |
+
checkpoint = None
|
855 |
+
|
856 |
+
key_reps = []
|
857 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
858 |
+
for key in state_dict.keys():
|
859 |
+
if key.startswith(rep_from):
|
860 |
+
new_key = rep_to + key[len(rep_from):]
|
861 |
+
key_reps.append((key, new_key))
|
862 |
+
|
863 |
+
for key, new_key in key_reps:
|
864 |
+
state_dict[new_key] = state_dict[key]
|
865 |
+
del state_dict[key]
|
866 |
+
|
867 |
+
return checkpoint, state_dict
|
868 |
+
|
869 |
+
|
870 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
871 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
872 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
873 |
+
if dtype is not None:
|
874 |
+
for k, v in state_dict.items():
|
875 |
+
if type(v) is torch.Tensor:
|
876 |
+
state_dict[k] = v.to(dtype)
|
877 |
+
|
878 |
+
# Convert the UNet2DConditionModel model.
|
879 |
+
unet_config = create_unet_diffusers_config(v2)
|
880 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
881 |
+
|
882 |
+
unet = UNet2DConditionModel(**unet_config)
|
883 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
884 |
+
print("loading u-net:", info)
|
885 |
+
|
886 |
+
# Convert the VAE model.
|
887 |
+
vae_config = create_vae_diffusers_config()
|
888 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
889 |
+
|
890 |
+
vae = AutoencoderKL(**vae_config)
|
891 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
892 |
+
print("loading vae:", info)
|
893 |
+
|
894 |
+
# convert text_model
|
895 |
+
if v2:
|
896 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
897 |
+
cfg = CLIPTextConfig(
|
898 |
+
vocab_size=49408,
|
899 |
+
hidden_size=1024,
|
900 |
+
intermediate_size=4096,
|
901 |
+
num_hidden_layers=23,
|
902 |
+
num_attention_heads=16,
|
903 |
+
max_position_embeddings=77,
|
904 |
+
hidden_act="gelu",
|
905 |
+
layer_norm_eps=1e-05,
|
906 |
+
dropout=0.0,
|
907 |
+
attention_dropout=0.0,
|
908 |
+
initializer_range=0.02,
|
909 |
+
initializer_factor=1.0,
|
910 |
+
pad_token_id=1,
|
911 |
+
bos_token_id=0,
|
912 |
+
eos_token_id=2,
|
913 |
+
model_type="clip_text_model",
|
914 |
+
projection_dim=512,
|
915 |
+
torch_dtype="float32",
|
916 |
+
transformers_version="4.25.0.dev0",
|
917 |
+
)
|
918 |
+
text_model = CLIPTextModel._from_config(cfg)
|
919 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
920 |
+
else:
|
921 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
922 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
923 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
924 |
+
print("loading text encoder:", info)
|
925 |
+
|
926 |
+
return text_model, vae, unet
|
927 |
+
|
928 |
+
|
929 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
930 |
+
def convert_key(key):
|
931 |
+
# position_idsの除去
|
932 |
+
if ".position_ids" in key:
|
933 |
+
return None
|
934 |
+
|
935 |
+
# common
|
936 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
937 |
+
key = key.replace("text_model.", "")
|
938 |
+
if "layers" in key:
|
939 |
+
# resblocks conversion
|
940 |
+
key = key.replace(".layers.", ".resblocks.")
|
941 |
+
if ".layer_norm" in key:
|
942 |
+
key = key.replace(".layer_norm", ".ln_")
|
943 |
+
elif ".mlp." in key:
|
944 |
+
key = key.replace(".fc1.", ".c_fc.")
|
945 |
+
key = key.replace(".fc2.", ".c_proj.")
|
946 |
+
elif '.self_attn.out_proj' in key:
|
947 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
948 |
+
elif '.self_attn.' in key:
|
949 |
+
key = None # 特殊なので後で処理する
|
950 |
+
else:
|
951 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
952 |
+
elif '.position_embedding' in key:
|
953 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
954 |
+
elif '.token_embedding' in key:
|
955 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
956 |
+
elif 'final_layer_norm' in key:
|
957 |
+
key = key.replace("final_layer_norm", "ln_final")
|
958 |
+
return key
|
959 |
+
|
960 |
+
keys = list(checkpoint.keys())
|
961 |
+
new_sd = {}
|
962 |
+
for key in keys:
|
963 |
+
new_key = convert_key(key)
|
964 |
+
if new_key is None:
|
965 |
+
continue
|
966 |
+
new_sd[new_key] = checkpoint[key]
|
967 |
+
|
968 |
+
# attnの変換
|
969 |
+
for key in keys:
|
970 |
+
if 'layers' in key and 'q_proj' in key:
|
971 |
+
# 三つを結合
|
972 |
+
key_q = key
|
973 |
+
key_k = key.replace("q_proj", "k_proj")
|
974 |
+
key_v = key.replace("q_proj", "v_proj")
|
975 |
+
|
976 |
+
value_q = checkpoint[key_q]
|
977 |
+
value_k = checkpoint[key_k]
|
978 |
+
value_v = checkpoint[key_v]
|
979 |
+
value = torch.cat([value_q, value_k, value_v])
|
980 |
+
|
981 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
982 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
983 |
+
new_sd[new_key] = value
|
984 |
+
|
985 |
+
# 最後の層などを捏造するか
|
986 |
+
if make_dummy_weights:
|
987 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
988 |
+
keys = list(new_sd.keys())
|
989 |
+
for key in keys:
|
990 |
+
if key.startswith("transformer.resblocks.22."):
|
991 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
992 |
+
|
993 |
+
# Diffusersに含まれない重みを作っておく
|
994 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
995 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
996 |
+
|
997 |
+
return new_sd
|
998 |
+
|
999 |
+
|
1000 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
1001 |
+
if ckpt_path is not None:
|
1002 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1003 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1004 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1005 |
+
checkpoint = {}
|
1006 |
+
strict = False
|
1007 |
+
else:
|
1008 |
+
strict = True
|
1009 |
+
if "state_dict" in state_dict:
|
1010 |
+
del state_dict["state_dict"]
|
1011 |
+
else:
|
1012 |
+
# 新しく作る
|
1013 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1014 |
+
checkpoint = {}
|
1015 |
+
state_dict = {}
|
1016 |
+
strict = False
|
1017 |
+
|
1018 |
+
def update_sd(prefix, sd):
|
1019 |
+
for k, v in sd.items():
|
1020 |
+
key = prefix + k
|
1021 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1022 |
+
if save_dtype is not None:
|
1023 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1024 |
+
state_dict[key] = v
|
1025 |
+
|
1026 |
+
# Convert the UNet model
|
1027 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1028 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1029 |
+
|
1030 |
+
# Convert the text encoder model
|
1031 |
+
if v2:
|
1032 |
+
make_dummy = ckpt_path is None # 参照元のcheckpoint���ない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1033 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1034 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1035 |
+
else:
|
1036 |
+
text_enc_dict = text_encoder.state_dict()
|
1037 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1038 |
+
|
1039 |
+
# Convert the VAE
|
1040 |
+
if vae is not None:
|
1041 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1042 |
+
update_sd("first_stage_model.", vae_dict)
|
1043 |
+
|
1044 |
+
# Put together new checkpoint
|
1045 |
+
key_count = len(state_dict.keys())
|
1046 |
+
new_ckpt = {'state_dict': state_dict}
|
1047 |
+
|
1048 |
+
if 'epoch' in checkpoint:
|
1049 |
+
epochs += checkpoint['epoch']
|
1050 |
+
if 'global_step' in checkpoint:
|
1051 |
+
steps += checkpoint['global_step']
|
1052 |
+
|
1053 |
+
new_ckpt['epoch'] = epochs
|
1054 |
+
new_ckpt['global_step'] = steps
|
1055 |
+
|
1056 |
+
if is_safetensors(output_file):
|
1057 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1058 |
+
save_file(state_dict, output_file)
|
1059 |
+
else:
|
1060 |
+
torch.save(new_ckpt, output_file)
|
1061 |
+
|
1062 |
+
return key_count
|
1063 |
+
|
1064 |
+
|
1065 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1066 |
+
if pretrained_model_name_or_path is None:
|
1067 |
+
# load default settings for v1/v2
|
1068 |
+
if v2:
|
1069 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1070 |
+
else:
|
1071 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1072 |
+
|
1073 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1074 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1075 |
+
if vae is None:
|
1076 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1077 |
+
|
1078 |
+
pipeline = StableDiffusionPipeline(
|
1079 |
+
unet=unet,
|
1080 |
+
text_encoder=text_encoder,
|
1081 |
+
vae=vae,
|
1082 |
+
scheduler=scheduler,
|
1083 |
+
tokenizer=tokenizer,
|
1084 |
+
safety_checker=None,
|
1085 |
+
feature_extractor=None,
|
1086 |
+
requires_safety_checker=None,
|
1087 |
+
)
|
1088 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1089 |
+
|
1090 |
+
|
1091 |
+
VAE_PREFIX = "first_stage_model."
|
1092 |
+
|
1093 |
+
|
1094 |
+
def load_vae(vae_id, dtype):
|
1095 |
+
print(f"load VAE: {vae_id}")
|
1096 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1097 |
+
# Diffusers local/remote
|
1098 |
+
try:
|
1099 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1100 |
+
except EnvironmentError as e:
|
1101 |
+
print(f"exception occurs in loading vae: {e}")
|
1102 |
+
print("retry with subfolder='vae'")
|
1103 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1104 |
+
return vae
|
1105 |
+
|
1106 |
+
# local
|
1107 |
+
vae_config = create_vae_diffusers_config()
|
1108 |
+
|
1109 |
+
if vae_id.endswith(".bin"):
|
1110 |
+
# SD 1.5 VAE on Huggingface
|
1111 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1112 |
+
else:
|
1113 |
+
# StableDiffusion
|
1114 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1115 |
+
else torch.load(vae_id, map_location="cpu"))
|
1116 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1117 |
+
|
1118 |
+
# vae only or full model
|
1119 |
+
full_model = False
|
1120 |
+
for vae_key in vae_sd:
|
1121 |
+
if vae_key.startswith(VAE_PREFIX):
|
1122 |
+
full_model = True
|
1123 |
+
break
|
1124 |
+
if not full_model:
|
1125 |
+
sd = {}
|
1126 |
+
for key, value in vae_sd.items():
|
1127 |
+
sd[VAE_PREFIX + key] = value
|
1128 |
+
vae_sd = sd
|
1129 |
+
del sd
|
1130 |
+
|
1131 |
+
# Convert the VAE model.
|
1132 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1133 |
+
|
1134 |
+
vae = AutoencoderKL(**vae_config)
|
1135 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1136 |
+
return vae
|
1137 |
+
|
1138 |
+
# endregion
|
1139 |
+
|
1140 |
+
|
1141 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1142 |
+
max_width, max_height = max_reso
|
1143 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1144 |
+
|
1145 |
+
resos = set()
|
1146 |
+
|
1147 |
+
size = int(math.sqrt(max_area)) * divisible
|
1148 |
+
resos.add((size, size))
|
1149 |
+
|
1150 |
+
size = min_size
|
1151 |
+
while size <= max_size:
|
1152 |
+
width = size
|
1153 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1154 |
+
resos.add((width, height))
|
1155 |
+
resos.add((height, width))
|
1156 |
+
|
1157 |
+
# # make additional resos
|
1158 |
+
# if width >= height and width - divisible >= min_size:
|
1159 |
+
# resos.add((width - divisible, height))
|
1160 |
+
# resos.add((height, width - divisible))
|
1161 |
+
# if height >= width and height - divisible >= min_size:
|
1162 |
+
# resos.add((width, height - divisible))
|
1163 |
+
# resos.add((height - divisible, width))
|
1164 |
+
|
1165 |
+
size += divisible
|
1166 |
+
|
1167 |
+
resos = list(resos)
|
1168 |
+
resos.sort()
|
1169 |
+
|
1170 |
+
aspect_ratios = [w / h for w, h in resos]
|
1171 |
+
return resos, aspect_ratios
|
1172 |
+
|
1173 |
+
|
1174 |
+
if __name__ == '__main__':
|
1175 |
+
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
1176 |
+
print(len(resos))
|
1177 |
+
print(resos)
|
1178 |
+
print(aspect_ratios)
|
1179 |
+
|
1180 |
+
ars = set()
|
1181 |
+
for ar in aspect_ratios:
|
1182 |
+
if ar in ars:
|
1183 |
+
print("error! duplicate ar:", ar)
|
1184 |
+
ars.add(ar)
|
locon/kohya_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
|
2 |
+
|
3 |
+
import hashlib
|
4 |
+
import safetensors
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
|
8 |
+
def addnet_hash_legacy(b):
|
9 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
10 |
+
m = hashlib.sha256()
|
11 |
+
|
12 |
+
b.seek(0x100000)
|
13 |
+
m.update(b.read(0x10000))
|
14 |
+
return m.hexdigest()[0:8]
|
15 |
+
|
16 |
+
|
17 |
+
def addnet_hash_safetensors(b):
|
18 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
19 |
+
hash_sha256 = hashlib.sha256()
|
20 |
+
blksize = 1024 * 1024
|
21 |
+
|
22 |
+
b.seek(0)
|
23 |
+
header = b.read(8)
|
24 |
+
n = int.from_bytes(header, "little")
|
25 |
+
|
26 |
+
offset = n + 8
|
27 |
+
b.seek(offset)
|
28 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
29 |
+
hash_sha256.update(chunk)
|
30 |
+
|
31 |
+
return hash_sha256.hexdigest()
|
32 |
+
|
33 |
+
|
34 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
35 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
36 |
+
save time on indexing the model later."""
|
37 |
+
|
38 |
+
# Because writing user metadata to the file can change the result of
|
39 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
40 |
+
# calculating the hash, as they are meant to be immutable
|
41 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
42 |
+
|
43 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
44 |
+
b = BytesIO(bytes)
|
45 |
+
|
46 |
+
model_hash = addnet_hash_safetensors(b)
|
47 |
+
legacy_hash = addnet_hash_legacy(b)
|
48 |
+
return model_hash, legacy_hash
|
locon/locon.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class LoConModule(nn.Module):
|
9 |
+
"""
|
10 |
+
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
14 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
15 |
+
super().__init__()
|
16 |
+
self.lora_name = lora_name
|
17 |
+
self.lora_dim = lora_dim
|
18 |
+
|
19 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
20 |
+
# For general LoCon
|
21 |
+
in_dim = org_module.in_channels
|
22 |
+
k_size = org_module.kernel_size
|
23 |
+
stride = org_module.stride
|
24 |
+
padding = org_module.padding
|
25 |
+
out_dim = org_module.out_channels
|
26 |
+
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
27 |
+
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
28 |
+
else:
|
29 |
+
in_dim = org_module.in_features
|
30 |
+
out_dim = org_module.out_features
|
31 |
+
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
32 |
+
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
33 |
+
|
34 |
+
if type(alpha) == torch.Tensor:
|
35 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
36 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
37 |
+
self.scale = alpha / self.lora_dim
|
38 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
39 |
+
|
40 |
+
# same as microsoft's
|
41 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
42 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
43 |
+
|
44 |
+
self.multiplier = multiplier
|
45 |
+
self.org_module = org_module # remove in applying
|
46 |
+
|
47 |
+
def apply_to(self):
|
48 |
+
self.org_forward = self.org_module.forward
|
49 |
+
self.org_module.forward = self.forward
|
50 |
+
del self.org_module
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
locon/locon_kohya.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoCon network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from typing import List
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .kohya_utils import *
|
13 |
+
from .locon import LoConModule
|
14 |
+
|
15 |
+
|
16 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
17 |
+
if network_dim is None:
|
18 |
+
network_dim = 4 # default
|
19 |
+
conv_dim = kwargs.get('conv_dim', network_dim)
|
20 |
+
conv_alpha = kwargs.get('conv_alpha', network_alpha)
|
21 |
+
network = LoRANetwork(
|
22 |
+
text_encoder, unet,
|
23 |
+
multiplier=multiplier,
|
24 |
+
lora_dim=network_dim, conv_lora_dim=conv_dim,
|
25 |
+
alpha=network_alpha, conv_alpha=conv_alpha
|
26 |
+
)
|
27 |
+
return network
|
28 |
+
|
29 |
+
|
30 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
31 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
32 |
+
from safetensors.torch import load_file, safe_open
|
33 |
+
weights_sd = load_file(file)
|
34 |
+
else:
|
35 |
+
weights_sd = torch.load(file, map_location='cpu')
|
36 |
+
|
37 |
+
# get dim (rank)
|
38 |
+
network_alpha = None
|
39 |
+
network_dim = None
|
40 |
+
for key, value in weights_sd.items():
|
41 |
+
if network_alpha is None and 'alpha' in key:
|
42 |
+
network_alpha = value
|
43 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
44 |
+
network_dim = value.size()[0]
|
45 |
+
|
46 |
+
if network_alpha is None:
|
47 |
+
network_alpha = network_dim
|
48 |
+
|
49 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
50 |
+
network.weights_sd = weights_sd
|
51 |
+
return network
|
52 |
+
|
53 |
+
torch.nn.Conv2d
|
54 |
+
class LoRANetwork(torch.nn.Module):
|
55 |
+
'''
|
56 |
+
LoRA + LoCon
|
57 |
+
'''
|
58 |
+
# Ignore proj_in or proj_out, their channels is only a few.
|
59 |
+
UNET_TARGET_REPLACE_MODULE = [
|
60 |
+
"Transformer2DModel",
|
61 |
+
"Attention",
|
62 |
+
"ResnetBlock2D",
|
63 |
+
"Downsample2D",
|
64 |
+
"Upsample2D"
|
65 |
+
]
|
66 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
67 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
68 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
text_encoder, unet,
|
73 |
+
multiplier=1.0,
|
74 |
+
lora_dim=4, conv_lora_dim=4,
|
75 |
+
alpha=1, conv_alpha=1
|
76 |
+
) -> None:
|
77 |
+
super().__init__()
|
78 |
+
self.multiplier = multiplier
|
79 |
+
self.lora_dim = lora_dim
|
80 |
+
self.conv_lora_dim = int(conv_lora_dim)
|
81 |
+
if self.conv_lora_dim != self.lora_dim:
|
82 |
+
print('Apply different lora dim for conv layer')
|
83 |
+
print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
|
84 |
+
self.alpha = alpha
|
85 |
+
self.conv_alpha = float(conv_alpha)
|
86 |
+
if self.alpha != self.conv_alpha:
|
87 |
+
print('Apply different alpha value for conv layer')
|
88 |
+
print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
|
89 |
+
|
90 |
+
# create module instances
|
91 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoConModule]:
|
92 |
+
print('Create LoCon Module')
|
93 |
+
loras = []
|
94 |
+
for name, module in root_module.named_modules():
|
95 |
+
if module.__class__.__name__ in target_replace_modules:
|
96 |
+
for child_name, child_module in module.named_modules():
|
97 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
98 |
+
lora_name = lora_name.replace('.', '_')
|
99 |
+
if child_module.__class__.__name__ == 'Linear':
|
100 |
+
lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
101 |
+
elif child_module.__class__.__name__ == 'Conv2d':
|
102 |
+
k_size, *_ = child_module.kernel_size
|
103 |
+
if k_size==1:
|
104 |
+
lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
105 |
+
else:
|
106 |
+
lora = LoConModule(lora_name, child_module, self.multiplier, self.conv_lora_dim, self.conv_alpha)
|
107 |
+
else:
|
108 |
+
continue
|
109 |
+
loras.append(lora)
|
110 |
+
return loras
|
111 |
+
|
112 |
+
self.text_encoder_loras = create_modules(
|
113 |
+
LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
114 |
+
text_encoder,
|
115 |
+
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
116 |
+
)
|
117 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
118 |
+
|
119 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
120 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
121 |
+
|
122 |
+
self.weights_sd = None
|
123 |
+
|
124 |
+
# assertion
|
125 |
+
names = set()
|
126 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
127 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
128 |
+
names.add(lora.lora_name)
|
129 |
+
|
130 |
+
def set_multiplier(self, multiplier):
|
131 |
+
self.multiplier = multiplier
|
132 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
133 |
+
lora.multiplier = self.multiplier
|
134 |
+
|
135 |
+
def load_weights(self, file):
|
136 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
137 |
+
from safetensors.torch import load_file, safe_open
|
138 |
+
self.weights_sd = load_file(file)
|
139 |
+
else:
|
140 |
+
self.weights_sd = torch.load(file, map_location='cpu')
|
141 |
+
|
142 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
143 |
+
if self.weights_sd:
|
144 |
+
weights_has_text_encoder = weights_has_unet = False
|
145 |
+
for key in self.weights_sd.keys():
|
146 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
147 |
+
weights_has_text_encoder = True
|
148 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
149 |
+
weights_has_unet = True
|
150 |
+
|
151 |
+
if apply_text_encoder is None:
|
152 |
+
apply_text_encoder = weights_has_text_encoder
|
153 |
+
else:
|
154 |
+
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
155 |
+
|
156 |
+
if apply_unet is None:
|
157 |
+
apply_unet = weights_has_unet
|
158 |
+
else:
|
159 |
+
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
160 |
+
else:
|
161 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
162 |
+
|
163 |
+
if apply_text_encoder:
|
164 |
+
print("enable LoRA for text encoder")
|
165 |
+
else:
|
166 |
+
self.text_encoder_loras = []
|
167 |
+
|
168 |
+
if apply_unet:
|
169 |
+
print("enable LoRA for U-Net")
|
170 |
+
else:
|
171 |
+
self.unet_loras = []
|
172 |
+
|
173 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
174 |
+
lora.apply_to()
|
175 |
+
self.add_module(lora.lora_name, lora)
|
176 |
+
|
177 |
+
if self.weights_sd:
|
178 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
179 |
+
info = self.load_state_dict(self.weights_sd, False)
|
180 |
+
print(f"weights are loaded: {info}")
|
181 |
+
|
182 |
+
def enable_gradient_checkpointing(self):
|
183 |
+
# not supported
|
184 |
+
pass
|
185 |
+
|
186 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
187 |
+
def enumerate_params(loras):
|
188 |
+
params = []
|
189 |
+
for lora in loras:
|
190 |
+
params.extend(lora.parameters())
|
191 |
+
return params
|
192 |
+
|
193 |
+
self.requires_grad_(True)
|
194 |
+
all_params = []
|
195 |
+
|
196 |
+
if self.text_encoder_loras:
|
197 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
198 |
+
if text_encoder_lr is not None:
|
199 |
+
param_data['lr'] = text_encoder_lr
|
200 |
+
all_params.append(param_data)
|
201 |
+
|
202 |
+
if self.unet_loras:
|
203 |
+
param_data = {'params': enumerate_params(self.unet_loras)}
|
204 |
+
if unet_lr is not None:
|
205 |
+
param_data['lr'] = unet_lr
|
206 |
+
all_params.append(param_data)
|
207 |
+
|
208 |
+
return all_params
|
209 |
+
|
210 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
211 |
+
self.requires_grad_(True)
|
212 |
+
|
213 |
+
def on_epoch_start(self, text_encoder, unet):
|
214 |
+
self.train()
|
215 |
+
|
216 |
+
def get_trainable_params(self):
|
217 |
+
return self.parameters()
|
218 |
+
|
219 |
+
def save_weights(self, file, dtype, metadata):
|
220 |
+
if metadata is not None and len(metadata) == 0:
|
221 |
+
metadata = None
|
222 |
+
|
223 |
+
state_dict = self.state_dict()
|
224 |
+
|
225 |
+
if dtype is not None:
|
226 |
+
for key in list(state_dict.keys()):
|
227 |
+
v = state_dict[key]
|
228 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
229 |
+
state_dict[key] = v
|
230 |
+
|
231 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
232 |
+
from safetensors.torch import save_file
|
233 |
+
|
234 |
+
# Precalculate model hashes to save time on indexing
|
235 |
+
if metadata is None:
|
236 |
+
metadata = {}
|
237 |
+
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
238 |
+
metadata["sshs_model_hash"] = model_hash
|
239 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
240 |
+
|
241 |
+
save_file(state_dict, file, metadata)
|
242 |
+
else:
|
243 |
+
torch.save(state_dict, file)
|
locon/utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import torch.linalg as linalg
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def extract_conv(
|
11 |
+
weight: nn.Parameter|torch.Tensor,
|
12 |
+
lora_rank = 8
|
13 |
+
) -> tuple[nn.Parameter, nn.Parameter]:
|
14 |
+
out_ch, in_ch, kernel_size, _ = weight.shape
|
15 |
+
lora_rank = min(out_ch, in_ch, lora_rank)
|
16 |
+
|
17 |
+
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
|
18 |
+
|
19 |
+
U = U[:, :lora_rank]
|
20 |
+
S = S[:lora_rank]
|
21 |
+
U = U @ torch.diag(S)
|
22 |
+
Vh = Vh[:lora_rank, :]
|
23 |
+
|
24 |
+
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
|
25 |
+
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
|
26 |
+
del U, S, Vh, weight
|
27 |
+
return extract_weight_A, extract_weight_B
|
28 |
+
|
29 |
+
|
30 |
+
def merge_conv(
|
31 |
+
weight_a: nn.Parameter|torch.Tensor,
|
32 |
+
weight_b: nn.Parameter|torch.Tensor,
|
33 |
+
):
|
34 |
+
rank, in_ch, kernel_size, k_ = weight_a.shape
|
35 |
+
out_ch, rank_, _, _ = weight_b.shape
|
36 |
+
|
37 |
+
assert rank == rank_ and kernel_size == k_
|
38 |
+
|
39 |
+
merged = weight_b.reshape(out_ch, -1) @ weight_a.reshape(rank, -1)
|
40 |
+
weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
|
41 |
+
return weight
|
42 |
+
|
43 |
+
|
44 |
+
def extract_linear(
|
45 |
+
weight: nn.Parameter|torch.Tensor,
|
46 |
+
lora_rank = 8
|
47 |
+
) -> tuple[nn.Parameter, nn.Parameter]:
|
48 |
+
out_ch, in_ch = weight.shape
|
49 |
+
lora_rank = min(out_ch, in_ch, lora_rank)
|
50 |
+
|
51 |
+
U, S, Vh = linalg.svd(weight)
|
52 |
+
|
53 |
+
U = U[:, :lora_rank]
|
54 |
+
S = S[:lora_rank]
|
55 |
+
U = U @ torch.diag(S)
|
56 |
+
Vh = Vh[:lora_rank, :]
|
57 |
+
|
58 |
+
extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
|
59 |
+
extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
|
60 |
+
del U, S, Vh, weight
|
61 |
+
return extract_weight_A, extract_weight_B
|
62 |
+
|
63 |
+
|
64 |
+
def merge_linear(
|
65 |
+
weight_a: nn.Parameter|torch.Tensor,
|
66 |
+
weight_b: nn.Parameter|torch.Tensor,
|
67 |
+
):
|
68 |
+
rank, in_ch = weight_a.shape
|
69 |
+
out_ch, rank_ = weight_b.shape
|
70 |
+
|
71 |
+
assert rank == rank_
|
72 |
+
|
73 |
+
weight = weight_b @ weight_a
|
74 |
+
return weight
|
75 |
+
|
76 |
+
|
77 |
+
def extract_diff(
|
78 |
+
base_model,
|
79 |
+
db_model,
|
80 |
+
lora_dim=4,
|
81 |
+
conv_lora_dim=4,
|
82 |
+
extract_device = 'cuda',
|
83 |
+
):
|
84 |
+
UNET_TARGET_REPLACE_MODULE = [
|
85 |
+
"Transformer2DModel",
|
86 |
+
"Attention",
|
87 |
+
"ResnetBlock2D",
|
88 |
+
"Downsample2D",
|
89 |
+
"Upsample2D"
|
90 |
+
]
|
91 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
+
def make_state_dict(
|
95 |
+
prefix,
|
96 |
+
root_module: torch.nn.Module,
|
97 |
+
target_module: torch.nn.Module,
|
98 |
+
target_replace_modules
|
99 |
+
):
|
100 |
+
loras = {}
|
101 |
+
temp = {}
|
102 |
+
|
103 |
+
for name, module in root_module.named_modules():
|
104 |
+
if module.__class__.__name__ in target_replace_modules:
|
105 |
+
temp[name] = {}
|
106 |
+
for child_name, child_module in module.named_modules():
|
107 |
+
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
108 |
+
continue
|
109 |
+
temp[name][child_name] = child_module.weight
|
110 |
+
|
111 |
+
for name, module in tqdm(list(target_module.named_modules())):
|
112 |
+
if name in temp:
|
113 |
+
weights = temp[name]
|
114 |
+
for child_name, child_module in module.named_modules():
|
115 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
116 |
+
lora_name = lora_name.replace('.', '_')
|
117 |
+
if child_module.__class__.__name__ == 'Linear':
|
118 |
+
extract_a, extract_b = extract_linear(
|
119 |
+
(child_module.weight - weights[child_name]),
|
120 |
+
lora_dim
|
121 |
+
)
|
122 |
+
elif child_module.__class__.__name__ == 'Conv2d':
|
123 |
+
extract_a, extract_b = extract_conv(
|
124 |
+
(child_module.weight - weights[child_name]),
|
125 |
+
conv_lora_dim
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
continue
|
129 |
+
|
130 |
+
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().half()
|
131 |
+
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().half()
|
132 |
+
loras[f'{lora_name}.alpha'] = torch.Tensor([int(extract_a.shape[0])]).detach().cpu().half()
|
133 |
+
del extract_a, extract_b
|
134 |
+
return loras
|
135 |
+
|
136 |
+
text_encoder_loras = make_state_dict(
|
137 |
+
LORA_PREFIX_TEXT_ENCODER,
|
138 |
+
base_model[0], db_model[0],
|
139 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
140 |
+
)
|
141 |
+
|
142 |
+
unet_loras = make_state_dict(
|
143 |
+
LORA_PREFIX_UNET,
|
144 |
+
base_model[2], db_model[2],
|
145 |
+
UNET_TARGET_REPLACE_MODULE
|
146 |
+
)
|
147 |
+
print(len(text_encoder_loras), len(unet_loras))
|
148 |
+
return text_encoder_loras|unet_loras
|
lora_train_popup.py
ADDED
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from functools import partial
|
5 |
+
from typing import Union
|
6 |
+
import os
|
7 |
+
import tkinter as tk
|
8 |
+
from tkinter import filedialog as fd, ttk
|
9 |
+
from tkinter import simpledialog as sd
|
10 |
+
from tkinter import messagebox as mb
|
11 |
+
|
12 |
+
import torch.cuda
|
13 |
+
import train_network
|
14 |
+
import library.train_util as util
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
|
18 |
+
class ArgStore:
|
19 |
+
# Represents the entirety of all possible inputs for sd-scripts. they are ordered from most important to least
|
20 |
+
def __init__(self):
|
21 |
+
# Important, these are the most likely things you will modify
|
22 |
+
self.base_model: str = r"" # example path, r"E:\sd\stable-diffusion-webui\models\Stable-diffusion\nai.ckpt"
|
23 |
+
self.img_folder: str = r"" # is the folder path to your img folder, make sure to follow the guide here for folder setup: https://rentry.org/2chAI_LoRA_Dreambooth_guide_english#for-kohyas-script
|
24 |
+
self.output_folder: str = r"" # just the folder all epochs/safetensors are output
|
25 |
+
self.change_output_name: Union[str, None] = None # changes the output name of the epochs
|
26 |
+
self.save_json_folder: Union[str, None] = None # OPTIONAL, saves a json folder of your config to whatever location you set here.
|
27 |
+
self.load_json_path: Union[str, None] = None # OPTIONAL, loads a json file partially changes the config to match. things like folder paths do not get modified.
|
28 |
+
self.json_load_skip_list: Union[list[str], None] = ["save_json_folder", "reg_img_folder",
|
29 |
+
"lora_model_for_resume", "change_output_name",
|
30 |
+
"training_comment",
|
31 |
+
"json_load_skip_list"] # OPTIONAL, allows the user to define what they skip when loading a json, by default it loads everything, including all paths, set it up like this ["base_model", "img_folder", "output_folder"]
|
32 |
+
self.caption_dropout_rate: Union[float, None] = None # The rate at which captions for files get dropped.
|
33 |
+
self.caption_dropout_every_n_epochs: Union[int, None] = None # Defines how often an epoch will completely ignore
|
34 |
+
# captions, EX. 3 means it will ignore captions at epochs 3, 6, and 9
|
35 |
+
self.caption_tag_dropout_rate: Union[float, None] = None # Defines the rate at which a tag would be dropped, rather than the entire caption file
|
36 |
+
self.noise_offset: Union[float, None] = None # OPTIONAL, seems to help allow SD to gen better blacks and whites
|
37 |
+
# Kohya recommends, if you have it set, to use 0.1, not sure how
|
38 |
+
# high the value can be, I'm going to assume maximum of 1
|
39 |
+
|
40 |
+
self.net_dim: int = 128 # network dimension, 128 is the most common, however you might be able to get lesser to work
|
41 |
+
self.alpha: float = 128 # represents the scalar for training. the lower the alpha, the less gets learned per step. if you want the older way of training, set this to dim
|
42 |
+
# list of schedulers: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
|
43 |
+
self.scheduler: str = "cosine_with_restarts" # the scheduler for learning rate. Each does something specific
|
44 |
+
self.cosine_restarts: Union[int, None] = 1 # OPTIONAL, represents the number of times it restarts. Only matters if you are using cosine_with_restarts
|
45 |
+
self.scheduler_power: Union[float, None] = 1 # OPTIONAL, represents the power of the polynomial. Only matters if you are using polynomial
|
46 |
+
self.warmup_lr_ratio: Union[float, None] = None # OPTIONAL, Calculates the number of warmup steps based on the ratio given. Make sure to set this if you are using constant_with_warmup, None to ignore
|
47 |
+
self.learning_rate: Union[float, None] = 1e-4 # OPTIONAL, when not set, lr gets set to 1e-3 as per adamW. Personally, I suggest actually setting this as lower lr seems to be a small bit better.
|
48 |
+
self.text_encoder_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the text encoder, this overwrites the base lr I believe, None to ignore
|
49 |
+
self.unet_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the unet, this overwrites the base lr I believe, None to ignore
|
50 |
+
self.num_workers: int = 1 # The number of threads that are being used to load images, lower speeds up the start of epochs, but slows down the loading of data. The assumption here is that it increases the training time as you reduce this value
|
51 |
+
self.persistent_workers: bool = True # makes workers persistent, further reduces/eliminates the lag in between epochs. however it may increase memory usage
|
52 |
+
|
53 |
+
self.batch_size: int = 1 # The number of images that get processed at one time, this is directly proportional to your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 batch size
|
54 |
+
self.num_epochs: int = 1 # The number of epochs, if you set max steps this value is ignored as it doesn't calculate steps.
|
55 |
+
self.save_every_n_epochs: Union[int, None] = None # OPTIONAL, how often to save epochs, None to ignore
|
56 |
+
self.shuffle_captions: bool = False # OPTIONAL, False to ignore
|
57 |
+
self.keep_tokens: Union[int, None] = None # OPTIONAL, None to ignore
|
58 |
+
self.max_steps: Union[int, None] = None # OPTIONAL, if you have specific steps you want to hit, this allows you to set it directly. None to ignore
|
59 |
+
self.tag_occurrence_txt_file: bool = False # OPTIONAL, creates a txt file that has the entire occurrence of all tags in your dataset
|
60 |
+
# the metadata will also have this so long as you have metadata on, so no reason to have this on by default
|
61 |
+
# will automatically output to the same folder as your output checkpoints
|
62 |
+
self.sort_tag_occurrence_alphabetically: bool = False # OPTIONAL, only applies if tag_occurrence_txt_file is also true
|
63 |
+
# Will change the output to be alphabetically vs being occurrence based
|
64 |
+
|
65 |
+
# These are the second most likely things you will modify
|
66 |
+
self.train_resolution: int = 512
|
67 |
+
self.min_bucket_resolution: int = 320
|
68 |
+
self.max_bucket_resolution: int = 960
|
69 |
+
self.lora_model_for_resume: Union[str, None] = None # OPTIONAL, takes an input lora to continue training from, not exactly the way it *should* be, but it works, None to ignore
|
70 |
+
self.save_state: bool = False # OPTIONAL, is the intended way to save a training state to use for continuing training, False to ignore
|
71 |
+
self.load_previous_save_state: Union[str, None] = None # OPTIONAL, is the intended way to load a training state to use for continuing training, None to ignore
|
72 |
+
self.training_comment: Union[str, None] = None # OPTIONAL, great way to put in things like activation tokens right into the metadata. seems to not work at this point and time
|
73 |
+
self.unet_only: bool = False # OPTIONAL, set it to only train the unet
|
74 |
+
self.text_only: bool = False # OPTIONAL, set it to only train the text encoder
|
75 |
+
|
76 |
+
# These are the least likely things you will modify
|
77 |
+
self.reg_img_folder: Union[str, None] = None # OPTIONAL, None to ignore
|
78 |
+
self.clip_skip: int = 2 # If you are training on a model that is anime based, keep this at 2 as most models are designed for that
|
79 |
+
self.test_seed: int = 23 # this is the "reproducable seed", basically if you set the seed to this, you should be able to input a prompt from one of your training images and get a close representation of it
|
80 |
+
self.prior_loss_weight: float = 1 # is the loss weight much like Dreambooth, is required for LoRA training
|
81 |
+
self.gradient_checkpointing: bool = False # OPTIONAL, enables gradient checkpointing
|
82 |
+
self.gradient_acc_steps: Union[int, None] = None # OPTIONAL, not sure exactly what this means
|
83 |
+
self.mixed_precision: str = "fp16" # If you have the ability to use bf16, do it, it's better
|
84 |
+
self.save_precision: str = "fp16" # You can also save in bf16, but because it's not universally supported, I suggest you keep saving at fp16
|
85 |
+
self.save_as: str = "safetensors" # list is pt, ckpt, safetensors
|
86 |
+
self.caption_extension: str = ".txt" # the other option is .captions, but since wd1.4 tagger outputs as txt files, this is the default
|
87 |
+
self.max_clip_token_length = 150 # can be 75, 150, or 225 I believe, there is no reason to go higher than 150 though
|
88 |
+
self.buckets: bool = True
|
89 |
+
self.xformers: bool = True
|
90 |
+
self.use_8bit_adam: bool = True
|
91 |
+
self.cache_latents: bool = True
|
92 |
+
self.color_aug: bool = False # IMPORTANT: Clashes with cache_latents, only have one of the two on!
|
93 |
+
self.flip_aug: bool = False
|
94 |
+
self.vae: Union[str, None] = None # Seems to only make results worse when not using that specific vae, should probably not use
|
95 |
+
self.no_meta: bool = False # This removes the metadata that now gets saved into safetensors, (you should keep this on)
|
96 |
+
self.log_dir: Union[str, None] = None # output of logs, not useful to most people.
|
97 |
+
self.v2: bool = False # Sets up training for SD2.1
|
98 |
+
self.v_parameterization: bool = False # Only is used when v2 is also set and you are using the 768x version of v2
|
99 |
+
|
100 |
+
# Creates the dict that is used for the rest of the code, to facilitate easier json saving and loading
|
101 |
+
@staticmethod
|
102 |
+
def convert_args_to_dict():
|
103 |
+
return ArgStore().__dict__
|
104 |
+
|
105 |
+
|
106 |
+
def main():
|
107 |
+
parser = argparse.ArgumentParser()
|
108 |
+
setup_args(parser)
|
109 |
+
pre_args = parser.parse_args()
|
110 |
+
queues = 0
|
111 |
+
args_queue = []
|
112 |
+
cont = True
|
113 |
+
while cont:
|
114 |
+
arg_dict = ArgStore.convert_args_to_dict()
|
115 |
+
ret = mb.askyesno(message="Do you want to load a json config file?")
|
116 |
+
if ret:
|
117 |
+
load_json(ask_file("select json to load from", {"json"}), arg_dict)
|
118 |
+
arg_dict = ask_elements_trunc(arg_dict)
|
119 |
+
else:
|
120 |
+
arg_dict = ask_elements(arg_dict)
|
121 |
+
if pre_args.save_json_path or arg_dict["save_json_folder"]:
|
122 |
+
save_json(pre_args.save_json_path if pre_args.save_json_path else arg_dict['save_json_folder'], arg_dict)
|
123 |
+
args = create_arg_space(arg_dict)
|
124 |
+
args = parser.parse_args(args)
|
125 |
+
queues += 1
|
126 |
+
args_queue.append(args)
|
127 |
+
if arg_dict['tag_occurrence_txt_file']:
|
128 |
+
get_occurrence_of_tags(arg_dict)
|
129 |
+
ret = mb.askyesno(message="Do you want to queue another training?")
|
130 |
+
if not ret:
|
131 |
+
cont = False
|
132 |
+
for args in args_queue:
|
133 |
+
try:
|
134 |
+
train_network.train(args)
|
135 |
+
except Exception as e:
|
136 |
+
print(f"Failed to train this set of args.\nSkipping this training session.\nError is: {e}")
|
137 |
+
gc.collect()
|
138 |
+
torch.cuda.empty_cache()
|
139 |
+
|
140 |
+
|
141 |
+
def create_arg_space(args: dict) -> [str]:
|
142 |
+
# This is the list of args that are to be used regardless of setup
|
143 |
+
output = ["--network_module=networks.lora", f"--pretrained_model_name_or_path={args['base_model']}",
|
144 |
+
f"--train_data_dir={args['img_folder']}", f"--output_dir={args['output_folder']}",
|
145 |
+
f"--prior_loss_weight={args['prior_loss_weight']}", f"--caption_extension=" + args['caption_extension'],
|
146 |
+
f"--resolution={args['train_resolution']}", f"--train_batch_size={args['batch_size']}",
|
147 |
+
f"--mixed_precision={args['mixed_precision']}", f"--save_precision={args['save_precision']}",
|
148 |
+
f"--network_dim={args['net_dim']}", f"--save_model_as={args['save_as']}",
|
149 |
+
f"--clip_skip={args['clip_skip']}", f"--seed={args['test_seed']}",
|
150 |
+
f"--max_token_length={args['max_clip_token_length']}", f"--lr_scheduler={args['scheduler']}",
|
151 |
+
f"--network_alpha={args['alpha']}", f"--max_data_loader_n_workers={args['num_workers']}"]
|
152 |
+
if not args['max_steps']:
|
153 |
+
output.append(f"--max_train_epochs={args['num_epochs']}")
|
154 |
+
output += create_optional_args(args, find_max_steps(args))
|
155 |
+
else:
|
156 |
+
output.append(f"--max_train_steps={args['max_steps']}")
|
157 |
+
output += create_optional_args(args, args['max_steps'])
|
158 |
+
return output
|
159 |
+
|
160 |
+
|
161 |
+
def create_optional_args(args: dict, steps):
|
162 |
+
output = []
|
163 |
+
if args["reg_img_folder"]:
|
164 |
+
output.append(f"--reg_data_dir={args['reg_img_folder']}")
|
165 |
+
|
166 |
+
if args['lora_model_for_resume']:
|
167 |
+
output.append(f"--network_weights={args['lora_model_for_resume']}")
|
168 |
+
|
169 |
+
if args['save_every_n_epochs']:
|
170 |
+
output.append(f"--save_every_n_epochs={args['save_every_n_epochs']}")
|
171 |
+
else:
|
172 |
+
output.append("--save_every_n_epochs=999999")
|
173 |
+
|
174 |
+
if args['shuffle_captions']:
|
175 |
+
output.append("--shuffle_caption")
|
176 |
+
|
177 |
+
if args['keep_tokens'] and args['keep_tokens'] > 0:
|
178 |
+
output.append(f"--keep_tokens={args['keep_tokens']}")
|
179 |
+
|
180 |
+
if args['buckets']:
|
181 |
+
output.append("--enable_bucket")
|
182 |
+
output.append(f"--min_bucket_reso={args['min_bucket_resolution']}")
|
183 |
+
output.append(f"--max_bucket_reso={args['max_bucket_resolution']}")
|
184 |
+
|
185 |
+
if args['use_8bit_adam']:
|
186 |
+
output.append("--use_8bit_adam")
|
187 |
+
|
188 |
+
if args['xformers']:
|
189 |
+
output.append("--xformers")
|
190 |
+
|
191 |
+
if args['color_aug']:
|
192 |
+
if args['cache_latents']:
|
193 |
+
print("color_aug and cache_latents conflict with one another. Please select only one")
|
194 |
+
quit(1)
|
195 |
+
output.append("--color_aug")
|
196 |
+
|
197 |
+
if args['flip_aug']:
|
198 |
+
output.append("--flip_aug")
|
199 |
+
|
200 |
+
if args['cache_latents']:
|
201 |
+
output.append("--cache_latents")
|
202 |
+
|
203 |
+
if args['warmup_lr_ratio'] and args['warmup_lr_ratio'] > 0:
|
204 |
+
warmup_steps = int(steps * args['warmup_lr_ratio'])
|
205 |
+
output.append(f"--lr_warmup_steps={warmup_steps}")
|
206 |
+
|
207 |
+
if args['gradient_checkpointing']:
|
208 |
+
output.append("--gradient_checkpointing")
|
209 |
+
|
210 |
+
if args['gradient_acc_steps'] and args['gradient_acc_steps'] > 0 and args['gradient_checkpointing']:
|
211 |
+
output.append(f"--gradient_accumulation_steps={args['gradient_acc_steps']}")
|
212 |
+
|
213 |
+
if args['learning_rate'] and args['learning_rate'] > 0:
|
214 |
+
output.append(f"--learning_rate={args['learning_rate']}")
|
215 |
+
|
216 |
+
if args['text_encoder_lr'] and args['text_encoder_lr'] > 0:
|
217 |
+
output.append(f"--text_encoder_lr={args['text_encoder_lr']}")
|
218 |
+
|
219 |
+
if args['unet_lr'] and args['unet_lr'] > 0:
|
220 |
+
output.append(f"--unet_lr={args['unet_lr']}")
|
221 |
+
|
222 |
+
if args['vae']:
|
223 |
+
output.append(f"--vae={args['vae']}")
|
224 |
+
|
225 |
+
if args['no_meta']:
|
226 |
+
output.append("--no_metadata")
|
227 |
+
|
228 |
+
if args['save_state']:
|
229 |
+
output.append("--save_state")
|
230 |
+
|
231 |
+
if args['load_previous_save_state']:
|
232 |
+
output.append(f"--resume={args['load_previous_save_state']}")
|
233 |
+
|
234 |
+
if args['change_output_name']:
|
235 |
+
output.append(f"--output_name={args['change_output_name']}")
|
236 |
+
|
237 |
+
if args['training_comment']:
|
238 |
+
output.append(f"--training_comment={args['training_comment']}")
|
239 |
+
|
240 |
+
if args['cosine_restarts'] and args['scheduler'] == "cosine_with_restarts":
|
241 |
+
output.append(f"--lr_scheduler_num_cycles={args['cosine_restarts']}")
|
242 |
+
|
243 |
+
if args['scheduler_power'] and args['scheduler'] == "polynomial":
|
244 |
+
output.append(f"--lr_scheduler_power={args['scheduler_power']}")
|
245 |
+
|
246 |
+
if args['persistent_workers']:
|
247 |
+
output.append(f"--persistent_data_loader_workers")
|
248 |
+
|
249 |
+
if args['unet_only']:
|
250 |
+
output.append("--network_train_unet_only")
|
251 |
+
|
252 |
+
if args['text_only'] and not args['unet_only']:
|
253 |
+
output.append("--network_train_text_encoder_only")
|
254 |
+
|
255 |
+
if args["log_dir"]:
|
256 |
+
output.append(f"--logging_dir={args['log_dir']}")
|
257 |
+
|
258 |
+
if args['caption_dropout_rate']:
|
259 |
+
output.append(f"--caption_dropout_rate={args['caption_dropout_rate']}")
|
260 |
+
|
261 |
+
if args['caption_dropout_every_n_epochs']:
|
262 |
+
output.append(f"--caption_dropout_every_n_epochs={args['caption_dropout_every_n_epochs']}")
|
263 |
+
|
264 |
+
if args['caption_tag_dropout_rate']:
|
265 |
+
output.append(f"--caption_tag_dropout_rate={args['caption_tag_dropout_rate']}")
|
266 |
+
|
267 |
+
if args['v2']:
|
268 |
+
output.append("--v2")
|
269 |
+
|
270 |
+
if args['v2'] and args['v_parameterization']:
|
271 |
+
output.append("--v_parameterization")
|
272 |
+
|
273 |
+
if args['noise_offset']:
|
274 |
+
output.append(f"--noise_offset={args['noise_offset']}")
|
275 |
+
return output
|
276 |
+
|
277 |
+
|
278 |
+
def find_max_steps(args: dict) -> int:
|
279 |
+
total_steps = 0
|
280 |
+
folders = os.listdir(args["img_folder"])
|
281 |
+
for folder in folders:
|
282 |
+
if not os.path.isdir(os.path.join(args["img_folder"], folder)):
|
283 |
+
continue
|
284 |
+
num_repeats = folder.split("_")
|
285 |
+
if len(num_repeats) < 2:
|
286 |
+
print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
|
287 |
+
continue
|
288 |
+
try:
|
289 |
+
num_repeats = int(num_repeats[0])
|
290 |
+
except ValueError:
|
291 |
+
print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
|
292 |
+
continue
|
293 |
+
imgs = 0
|
294 |
+
for file in os.listdir(os.path.join(args["img_folder"], folder)):
|
295 |
+
if os.path.isdir(file):
|
296 |
+
continue
|
297 |
+
ext = file.split(".")
|
298 |
+
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
|
299 |
+
imgs += 1
|
300 |
+
total_steps += (num_repeats * imgs)
|
301 |
+
total_steps = int((total_steps / args["batch_size"]) * args["num_epochs"])
|
302 |
+
return total_steps
|
303 |
+
|
304 |
+
|
305 |
+
def add_misc_args(parser):
|
306 |
+
parser.add_argument("--save_json_path", type=str, default=None,
|
307 |
+
help="Path to save a configuration json file to")
|
308 |
+
parser.add_argument("--load_json_path", type=str, default=None,
|
309 |
+
help="Path to a json file to configure things from")
|
310 |
+
parser.add_argument("--no_metadata", action='store_true',
|
311 |
+
help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
312 |
+
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
313 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
|
314 |
+
|
315 |
+
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
316 |
+
parser.add_argument("--text_encoder_lr", type=float, default=None,
|
317 |
+
help="learning rate for Text Encoder / Text Encoderの学習率")
|
318 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
319 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
320 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
321 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
322 |
+
|
323 |
+
parser.add_argument("--network_weights", type=str, default=None,
|
324 |
+
help="pretrained weights for network / 学習するネットワークの初期重み")
|
325 |
+
parser.add_argument("--network_module", type=str, default=None,
|
326 |
+
help='network module to train / 学習対象のネットワークのモジュール')
|
327 |
+
parser.add_argument("--network_dim", type=int, default=None,
|
328 |
+
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
329 |
+
parser.add_argument("--network_alpha", type=float, default=1,
|
330 |
+
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
|
331 |
+
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
332 |
+
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
333 |
+
parser.add_argument("--network_train_unet_only", action="store_true",
|
334 |
+
help="only training U-Net part / U-Net関連部分のみ学習する")
|
335 |
+
parser.add_argument("--network_train_text_encoder_only", action="store_true",
|
336 |
+
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
337 |
+
parser.add_argument("--training_comment", type=str, default=None,
|
338 |
+
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
339 |
+
|
340 |
+
|
341 |
+
def setup_args(parser):
|
342 |
+
util.add_sd_models_arguments(parser)
|
343 |
+
util.add_dataset_arguments(parser, True, True, True)
|
344 |
+
util.add_training_arguments(parser, True)
|
345 |
+
add_misc_args(parser)
|
346 |
+
|
347 |
+
|
348 |
+
def get_occurrence_of_tags(args):
|
349 |
+
extension = args['caption_extension']
|
350 |
+
img_folder = args['img_folder']
|
351 |
+
output_folder = args['output_folder']
|
352 |
+
occurrence_dict = {}
|
353 |
+
print(img_folder)
|
354 |
+
for folder in os.listdir(img_folder):
|
355 |
+
print(folder)
|
356 |
+
if not os.path.isdir(os.path.join(img_folder, folder)):
|
357 |
+
continue
|
358 |
+
for file in os.listdir(os.path.join(img_folder, folder)):
|
359 |
+
if not os.path.isfile(os.path.join(img_folder, folder, file)):
|
360 |
+
continue
|
361 |
+
ext = os.path.splitext(file)[1]
|
362 |
+
if ext != extension:
|
363 |
+
continue
|
364 |
+
get_tags_from_file(os.path.join(img_folder, folder, file), occurrence_dict)
|
365 |
+
if not args['sort_tag_occurrence_alphabetically']:
|
366 |
+
output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[1], reverse=True)}
|
367 |
+
else:
|
368 |
+
output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[0])}
|
369 |
+
name = args['change_output_name'] if args['change_output_name'] else "last"
|
370 |
+
with open(os.path.join(output_folder, f"{name}.txt"), "w") as f:
|
371 |
+
f.write(f"Below is a list of keywords used during the training of {args['change_output_name']}:\n")
|
372 |
+
for k, v in output_list.items():
|
373 |
+
f.write(f"[{v}] {k}\n")
|
374 |
+
print(f"Created a txt file named {name}.txt in the output folder")
|
375 |
+
|
376 |
+
|
377 |
+
def get_tags_from_file(file, occurrence_dict):
|
378 |
+
f = open(file)
|
379 |
+
temp = f.read().replace(", ", ",").split(",")
|
380 |
+
f.close()
|
381 |
+
for tag in temp:
|
382 |
+
if tag in occurrence_dict:
|
383 |
+
occurrence_dict[tag] += 1
|
384 |
+
else:
|
385 |
+
occurrence_dict[tag] = 1
|
386 |
+
|
387 |
+
|
388 |
+
def ask_file(message, accepted_ext_list, file_path=None):
|
389 |
+
mb.showinfo(message=message)
|
390 |
+
res = ""
|
391 |
+
_initialdir = ""
|
392 |
+
_initialfile = ""
|
393 |
+
if file_path != None:
|
394 |
+
_initialdir = os.path.dirname(file_path) if os.path.exists(file_path) else ""
|
395 |
+
_initialfile = os.path.basename(file_path) if os.path.exists(file_path) else ""
|
396 |
+
|
397 |
+
while res == "":
|
398 |
+
res = fd.askopenfilename(title=message, initialdir=_initialdir, initialfile=_initialfile)
|
399 |
+
if res == "" or type(res) == tuple:
|
400 |
+
ret = mb.askretrycancel(message="Do you want to to cancel training?")
|
401 |
+
if not ret:
|
402 |
+
exit()
|
403 |
+
continue
|
404 |
+
elif not os.path.exists(res):
|
405 |
+
res = ""
|
406 |
+
continue
|
407 |
+
_, name = os.path.split(res)
|
408 |
+
split_name = name.split(".")
|
409 |
+
if split_name[-1] not in accepted_ext_list:
|
410 |
+
res = ""
|
411 |
+
return res
|
412 |
+
|
413 |
+
|
414 |
+
def ask_dir(message, dir_path=None):
|
415 |
+
mb.showinfo(message=message)
|
416 |
+
res = ""
|
417 |
+
_initialdir = ""
|
418 |
+
if dir_path != None:
|
419 |
+
_initialdir = dir_path if os.path.exists(dir_path) else ""
|
420 |
+
while res == "":
|
421 |
+
res = fd.askdirectory(title=message, initialdir=_initialdir)
|
422 |
+
if res == "" or type(res) == tuple:
|
423 |
+
ret = mb.askretrycancel(message="Do you want to to cancel training?")
|
424 |
+
if not ret:
|
425 |
+
exit()
|
426 |
+
continue
|
427 |
+
if not os.path.exists(res):
|
428 |
+
res = ""
|
429 |
+
return res
|
430 |
+
|
431 |
+
|
432 |
+
def ask_elements_trunc(args: dict):
|
433 |
+
args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
|
434 |
+
args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
|
435 |
+
args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
|
436 |
+
|
437 |
+
ret = mb.askyesno(message="Do you want to save a json of your configuration?")
|
438 |
+
if ret:
|
439 |
+
args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
|
440 |
+
else:
|
441 |
+
args['save_json_folder'] = None
|
442 |
+
|
443 |
+
ret = mb.askyesno(message="Are you training on a SD2 based model?")
|
444 |
+
if ret:
|
445 |
+
args['v2'] = True
|
446 |
+
|
447 |
+
ret = mb.askyesno(message="Are you training on an realistic model?")
|
448 |
+
if ret:
|
449 |
+
args['clip_skip'] = 1
|
450 |
+
|
451 |
+
if args['v2']:
|
452 |
+
ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
|
453 |
+
if ret:
|
454 |
+
args['v_parameterization'] = True
|
455 |
+
|
456 |
+
ret = mb.askyesno(message="Do you want to use regularization images?")
|
457 |
+
if ret:
|
458 |
+
args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
|
459 |
+
else:
|
460 |
+
args['reg_img_folder'] = None
|
461 |
+
|
462 |
+
ret = mb.askyesno(message="Do you want to continue from an earlier version?")
|
463 |
+
if ret:
|
464 |
+
args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
|
465 |
+
args['lora_model_for_resume'])
|
466 |
+
else:
|
467 |
+
args['lora_model_for_resume'] = None
|
468 |
+
|
469 |
+
ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
|
470 |
+
"within your dataset but it can also ruin learning an asymmetrical element\n")
|
471 |
+
if ret:
|
472 |
+
args['flip_aug'] = True
|
473 |
+
|
474 |
+
ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
|
475 |
+
if ret:
|
476 |
+
ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
|
477 |
+
"Cancel keeps outputs the original")
|
478 |
+
if ret:
|
479 |
+
args['change_output_name'] = ret
|
480 |
+
else:
|
481 |
+
args['change_output_name'] = None
|
482 |
+
|
483 |
+
ret = sd.askstring(title="comment",
|
484 |
+
prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
|
485 |
+
"be to include how to use, such as activation keywords.\nCancel will leave empty")
|
486 |
+
if ret is None:
|
487 |
+
args['training_comment'] = ret
|
488 |
+
else:
|
489 |
+
args['training_comment'] = None
|
490 |
+
|
491 |
+
ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
|
492 |
+
if ret:
|
493 |
+
button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
|
494 |
+
button.window.mainloop()
|
495 |
+
if button.current_value != "":
|
496 |
+
args[button.current_value] = True
|
497 |
+
|
498 |
+
ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
|
499 |
+
"of all tags that you have used in your training data?\n")
|
500 |
+
if ret:
|
501 |
+
args['tag_occurrence_txt_file'] = True
|
502 |
+
button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
|
503 |
+
button.window.mainloop()
|
504 |
+
if button.current_value == "alphabetically":
|
505 |
+
args['sort_tag_occurrence_alphabetically'] = True
|
506 |
+
|
507 |
+
ret = mb.askyesno(message="Do you want to use caption dropout?")
|
508 |
+
if ret:
|
509 |
+
ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
|
510 |
+
if ret:
|
511 |
+
ret = sd.askinteger(title="Caption_File_Dropout",
|
512 |
+
prompt="How often do you want caption files to drop out?\n"
|
513 |
+
"enter a number from 0 to 100 that is the percentage chance of dropout\n"
|
514 |
+
"Cancel sets to 0")
|
515 |
+
if ret and 0 <= ret <= 100:
|
516 |
+
args['caption_dropout_rate'] = ret / 100.0
|
517 |
+
|
518 |
+
ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
|
519 |
+
if ret:
|
520 |
+
ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
|
521 |
+
"epoch with no captions\nSo if you set 3, then every"
|
522 |
+
"three epochs will not have captions (3, 6, 9)\n"
|
523 |
+
"Cancel will set to None")
|
524 |
+
if ret:
|
525 |
+
args['caption_dropout_every_n_epochs'] = ret
|
526 |
+
|
527 |
+
ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
|
528 |
+
if ret:
|
529 |
+
ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
|
530 |
+
"Enter a number between 0 and 100, that is the percentage"
|
531 |
+
"chance of dropout.\nCancel sets to 0")
|
532 |
+
if ret and 0 <= ret <= 100:
|
533 |
+
args['caption_tag_dropout_rate'] = ret / 100.0
|
534 |
+
|
535 |
+
ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
|
536 |
+
"darker or lighter images using this than normal.")
|
537 |
+
if ret:
|
538 |
+
ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
|
539 |
+
"but it can go higher. Cancel defaults to 0.1")
|
540 |
+
if ret:
|
541 |
+
args['noise_offset'] = ret
|
542 |
+
else:
|
543 |
+
args['noise_offset'] = 0.1
|
544 |
+
return args
|
545 |
+
|
546 |
+
|
547 |
+
def ask_elements(args: dict):
|
548 |
+
# start with file dialog
|
549 |
+
args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
|
550 |
+
args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
|
551 |
+
args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
|
552 |
+
|
553 |
+
# optional file dialog
|
554 |
+
ret = mb.askyesno(message="Do you want to save a json of your configuration?")
|
555 |
+
if ret:
|
556 |
+
args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
|
557 |
+
else:
|
558 |
+
args['save_json_folder'] = None
|
559 |
+
|
560 |
+
ret = mb.askyesno(message="Are you training on a SD2 based model?")
|
561 |
+
if ret:
|
562 |
+
args['v2'] = True
|
563 |
+
|
564 |
+
ret = mb.askyesno(message="Are you training on an realistic model?")
|
565 |
+
if ret:
|
566 |
+
args['clip_skip'] = 1
|
567 |
+
|
568 |
+
if args['v2']:
|
569 |
+
ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
|
570 |
+
if ret:
|
571 |
+
args['v_parameterization'] = True
|
572 |
+
|
573 |
+
ret = mb.askyesno(message="Do you want to use regularization images?")
|
574 |
+
if ret:
|
575 |
+
args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
|
576 |
+
else:
|
577 |
+
args['reg_img_folder'] = None
|
578 |
+
|
579 |
+
ret = mb.askyesno(message="Do you want to continue from an earlier version?")
|
580 |
+
if ret:
|
581 |
+
args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
|
582 |
+
args['lora_model_for_resume'])
|
583 |
+
else:
|
584 |
+
args['lora_model_for_resume'] = None
|
585 |
+
|
586 |
+
ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
|
587 |
+
"within your dataset but it can also ruin learning an asymmetrical element\n")
|
588 |
+
if ret:
|
589 |
+
args['flip_aug'] = True
|
590 |
+
|
591 |
+
# text based required elements
|
592 |
+
ret = sd.askinteger(title="batch_size",
|
593 |
+
prompt="The number of images that get processed at one time, this is directly proportional to "
|
594 |
+
"your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 "
|
595 |
+
"batch size\nHow large is your batch size going to be?\nCancel will default to 1")
|
596 |
+
if ret is None:
|
597 |
+
args['batch_size'] = 1
|
598 |
+
else:
|
599 |
+
args['batch_size'] = ret
|
600 |
+
|
601 |
+
ret = sd.askinteger(title="num_epochs", prompt="How many epochs do you want?\nCancel will default to 1")
|
602 |
+
if ret is None:
|
603 |
+
args['num_epochs'] = 1
|
604 |
+
else:
|
605 |
+
args['num_epochs'] = ret
|
606 |
+
|
607 |
+
ret = sd.askinteger(title="network_dim", prompt="What is the dim size you want to use?\nCancel will default to 128")
|
608 |
+
if ret is None:
|
609 |
+
args['net_dim'] = 128
|
610 |
+
else:
|
611 |
+
args['net_dim'] = ret
|
612 |
+
|
613 |
+
ret = sd.askfloat(title="alpha", prompt="Alpha is the scalar of the training, generally a good starting point is "
|
614 |
+
"0.5x dim size\nWhat Alpha do you want?\nCancel will default to equal to "
|
615 |
+
"0.5 x network_dim")
|
616 |
+
if ret is None:
|
617 |
+
args['alpha'] = args['net_dim'] / 2
|
618 |
+
else:
|
619 |
+
args['alpha'] = ret
|
620 |
+
|
621 |
+
ret = sd.askinteger(title="resolution", prompt="How large of a resolution do you want to train at?\n"
|
622 |
+
"Cancel will default to 512")
|
623 |
+
if ret is None:
|
624 |
+
args['train_resolution'] = 512
|
625 |
+
else:
|
626 |
+
args['train_resolution'] = ret
|
627 |
+
|
628 |
+
ret = sd.askfloat(title="learning_rate", prompt="What learning rate do you want to use?\n"
|
629 |
+
"Cancel will default to 1e-4")
|
630 |
+
if ret is None:
|
631 |
+
args['learning_rate'] = 1e-4
|
632 |
+
else:
|
633 |
+
args['learning_rate'] = ret
|
634 |
+
|
635 |
+
ret = sd.askfloat(title="text_encoder_lr", prompt="Do you want to set the text_encoder_lr?\n"
|
636 |
+
"Cancel will default to None")
|
637 |
+
if ret is None:
|
638 |
+
args['text_encoder_lr'] = None
|
639 |
+
else:
|
640 |
+
args['text_encoder_lr'] = ret
|
641 |
+
|
642 |
+
ret = sd.askfloat(title="unet_lr", prompt="Do you want to set the unet_lr?\nCancel will default to None")
|
643 |
+
if ret is None:
|
644 |
+
args['unet_lr'] = None
|
645 |
+
else:
|
646 |
+
args['unet_lr'] = ret
|
647 |
+
|
648 |
+
button = ButtonBox("Which scheduler do you want?", ["cosine_with_restarts", "cosine", "polynomial",
|
649 |
+
"constant", "constant_with_warmup", "linear"])
|
650 |
+
button.window.mainloop()
|
651 |
+
args['scheduler'] = button.current_value if button.current_value != "" else "cosine_with_restarts"
|
652 |
+
|
653 |
+
if args['scheduler'] == "cosine_with_restarts":
|
654 |
+
ret = sd.askinteger(title="Cycle Count",
|
655 |
+
prompt="How many times do you want cosine to restart?\nThis is the entire amount of times "
|
656 |
+
"it will restart for the entire training\nCancel will default to 1")
|
657 |
+
if ret is None:
|
658 |
+
args['cosine_restarts'] = 1
|
659 |
+
else:
|
660 |
+
args['cosine_restarts'] = ret
|
661 |
+
|
662 |
+
if args['scheduler'] == "polynomial":
|
663 |
+
ret = sd.askfloat(title="Poly Strength",
|
664 |
+
prompt="What power do you want to set your polynomial to?\nhigher power means that the "
|
665 |
+
"model reduces the learning more more aggressively from initial training.\n1 = "
|
666 |
+
"linear\nCancel sets to 1")
|
667 |
+
if ret is None:
|
668 |
+
args['scheduler_power'] = 1
|
669 |
+
else:
|
670 |
+
args['scheduler_power'] = ret
|
671 |
+
|
672 |
+
ret = mb.askyesno(message="Do you want to save epochs as it trains?")
|
673 |
+
if ret:
|
674 |
+
ret = sd.askinteger(title="save_epoch",
|
675 |
+
prompt="How often do you want to save epochs?\nCancel will default to 1")
|
676 |
+
if ret is None:
|
677 |
+
args['save_every_n_epochs'] = 1
|
678 |
+
else:
|
679 |
+
args['save_every_n_epochs'] = ret
|
680 |
+
|
681 |
+
ret = mb.askyesno(message="Do you want to shuffle captions?")
|
682 |
+
if ret:
|
683 |
+
args['shuffle_captions'] = True
|
684 |
+
else:
|
685 |
+
args['shuffle_captions'] = False
|
686 |
+
|
687 |
+
ret = mb.askyesno(message="Do you want to keep some tokens at the front of your captions?")
|
688 |
+
if ret:
|
689 |
+
ret = sd.askinteger(title="keep_tokens", prompt="How many do you want to keep at the front?"
|
690 |
+
"\nCancel will default to 1")
|
691 |
+
if ret is None:
|
692 |
+
args['keep_tokens'] = 1
|
693 |
+
else:
|
694 |
+
args['keep_tokens'] = ret
|
695 |
+
|
696 |
+
ret = mb.askyesno(message="Do you want to have a warmup ratio?")
|
697 |
+
if ret:
|
698 |
+
ret = sd.askfloat(title="warmup_ratio", prompt="What is the ratio of steps to use as warmup "
|
699 |
+
"steps?\nCancel will default to None")
|
700 |
+
if ret is None:
|
701 |
+
args['warmup_lr_ratio'] = None
|
702 |
+
else:
|
703 |
+
args['warmup_lr_ratio'] = ret
|
704 |
+
|
705 |
+
ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
|
706 |
+
if ret:
|
707 |
+
ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
|
708 |
+
"Cancel keeps outputs the original")
|
709 |
+
if ret:
|
710 |
+
args['change_output_name'] = ret
|
711 |
+
else:
|
712 |
+
args['change_output_name'] = None
|
713 |
+
|
714 |
+
ret = sd.askstring(title="comment",
|
715 |
+
prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
|
716 |
+
"be to include how to use, such as activation keywords.\nCancel will leave empty")
|
717 |
+
if ret is None:
|
718 |
+
args['training_comment'] = ret
|
719 |
+
else:
|
720 |
+
args['training_comment'] = None
|
721 |
+
|
722 |
+
ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
|
723 |
+
if ret:
|
724 |
+
if ret:
|
725 |
+
button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
|
726 |
+
button.window.mainloop()
|
727 |
+
if button.current_value != "":
|
728 |
+
args[button.current_value] = True
|
729 |
+
|
730 |
+
ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
|
731 |
+
"of all tags that you have used in your training data?\n")
|
732 |
+
if ret:
|
733 |
+
args['tag_occurrence_txt_file'] = True
|
734 |
+
button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
|
735 |
+
button.window.mainloop()
|
736 |
+
if button.current_value == "alphabetically":
|
737 |
+
args['sort_tag_occurrence_alphabetically'] = True
|
738 |
+
|
739 |
+
ret = mb.askyesno(message="Do you want to use caption dropout?")
|
740 |
+
if ret:
|
741 |
+
ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
|
742 |
+
if ret:
|
743 |
+
ret = sd.askinteger(title="Caption_File_Dropout",
|
744 |
+
prompt="How often do you want caption files to drop out?\n"
|
745 |
+
"enter a number from 0 to 100 that is the percentage chance of dropout\n"
|
746 |
+
"Cancel sets to 0")
|
747 |
+
if ret and 0 <= ret <= 100:
|
748 |
+
args['caption_dropout_rate'] = ret / 100.0
|
749 |
+
|
750 |
+
ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
|
751 |
+
if ret:
|
752 |
+
ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
|
753 |
+
"epoch with no captions\nSo if you set 3, then every"
|
754 |
+
"three epochs will not have captions (3, 6, 9)\n"
|
755 |
+
"Cancel will set to None")
|
756 |
+
if ret:
|
757 |
+
args['caption_dropout_every_n_epochs'] = ret
|
758 |
+
|
759 |
+
ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
|
760 |
+
if ret:
|
761 |
+
ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
|
762 |
+
"Enter a number between 0 and 100, that is the percentage"
|
763 |
+
"chance of dropout.\nCancel sets to 0")
|
764 |
+
if ret and 0 <= ret <= 100:
|
765 |
+
args['caption_tag_dropout_rate'] = ret / 100.0
|
766 |
+
|
767 |
+
ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
|
768 |
+
"darker or lighter images using this than normal.")
|
769 |
+
if ret:
|
770 |
+
ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
|
771 |
+
"but it can go higher. Cancel defaults to 0.1")
|
772 |
+
if ret:
|
773 |
+
args['noise_offset'] = ret
|
774 |
+
else:
|
775 |
+
args['noise_offset'] = 0.1
|
776 |
+
return args
|
777 |
+
|
778 |
+
|
779 |
+
def save_json(path, obj: dict) -> None:
|
780 |
+
fp = open(os.path.join(path, f"config-{time.time()}.json"), "w")
|
781 |
+
json.dump(obj, fp=fp, indent=4)
|
782 |
+
fp.close()
|
783 |
+
|
784 |
+
|
785 |
+
def load_json(path, obj: dict) -> dict:
|
786 |
+
with open(path) as f:
|
787 |
+
json_obj = json.loads(f.read())
|
788 |
+
print("loaded json, setting variables...")
|
789 |
+
ui_name_scheme = {"pretrained_model_name_or_path": "base_model", "logging_dir": "log_dir",
|
790 |
+
"train_data_dir": "img_folder", "reg_data_dir": "reg_img_folder",
|
791 |
+
"output_dir": "output_folder", "max_resolution": "train_resolution",
|
792 |
+
"lr_scheduler": "scheduler", "lr_warmup": "warmup_lr_ratio",
|
793 |
+
"train_batch_size": "batch_size", "epoch": "num_epochs",
|
794 |
+
"save_at_n_epochs": "save_every_n_epochs", "num_cpu_threads_per_process": "num_workers",
|
795 |
+
"enable_bucket": "buckets", "save_model_as": "save_as", "shuffle_caption": "shuffle_captions",
|
796 |
+
"resume": "load_previous_save_state", "network_dim": "net_dim",
|
797 |
+
"gradient_accumulation_steps": "gradient_acc_steps", "output_name": "change_output_name",
|
798 |
+
"network_alpha": "alpha", "lr_scheduler_num_cycles": "cosine_restarts",
|
799 |
+
"lr_scheduler_power": "scheduler_power"}
|
800 |
+
|
801 |
+
for key in list(json_obj):
|
802 |
+
if key in ui_name_scheme:
|
803 |
+
json_obj[ui_name_scheme[key]] = json_obj[key]
|
804 |
+
if ui_name_scheme[key] in {"batch_size", "num_epochs"}:
|
805 |
+
try:
|
806 |
+
json_obj[ui_name_scheme[key]] = int(json_obj[ui_name_scheme[key]])
|
807 |
+
except ValueError:
|
808 |
+
print(f"attempting to load {key} from json failed as input isn't an integer")
|
809 |
+
quit(1)
|
810 |
+
|
811 |
+
for key in list(json_obj):
|
812 |
+
if obj["json_load_skip_list"] and key in obj["json_load_skip_list"]:
|
813 |
+
continue
|
814 |
+
if key in obj:
|
815 |
+
if key in {"keep_tokens", "warmup_lr_ratio"}:
|
816 |
+
json_obj[key] = int(json_obj[key]) if json_obj[key] is not None else None
|
817 |
+
if key in {"learning_rate", "unet_lr", "text_encoder_lr"}:
|
818 |
+
json_obj[key] = float(json_obj[key]) if json_obj[key] is not None else None
|
819 |
+
if obj[key] != json_obj[key]:
|
820 |
+
print_change(key, obj[key], json_obj[key])
|
821 |
+
obj[key] = json_obj[key]
|
822 |
+
print("completed changing variables.")
|
823 |
+
return obj
|
824 |
+
|
825 |
+
|
826 |
+
def print_change(value, old, new):
|
827 |
+
print(f"{value} changed from {old} to {new}")
|
828 |
+
|
829 |
+
|
830 |
+
class ButtonBox:
|
831 |
+
def __init__(self, label: str, button_name_list: list[str]) -> None:
|
832 |
+
self.window = tk.Tk()
|
833 |
+
self.button_list = []
|
834 |
+
self.current_value = ""
|
835 |
+
|
836 |
+
self.window.attributes("-topmost", True)
|
837 |
+
self.window.resizable(False, False)
|
838 |
+
self.window.eval('tk::PlaceWindow . center')
|
839 |
+
|
840 |
+
def del_window():
|
841 |
+
self.window.quit()
|
842 |
+
self.window.destroy()
|
843 |
+
|
844 |
+
self.window.protocol("WM_DELETE_WINDOW", del_window)
|
845 |
+
tk.Label(text=label, master=self.window).pack()
|
846 |
+
for button in button_name_list:
|
847 |
+
self.button_list.append(ttk.Button(text=button, master=self.window,
|
848 |
+
command=partial(self.set_current_value, button)))
|
849 |
+
self.button_list[-1].pack()
|
850 |
+
|
851 |
+
def set_current_value(self, value):
|
852 |
+
self.current_value = value
|
853 |
+
self.window.quit()
|
854 |
+
self.window.destroy()
|
855 |
+
|
856 |
+
|
857 |
+
root = tk.Tk()
|
858 |
+
root.attributes('-topmost', True)
|
859 |
+
root.withdraw()
|
860 |
+
|
861 |
+
if __name__ == "__main__":
|
862 |
+
main()
|
lycoris/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lycoris import (
|
2 |
+
kohya,
|
3 |
+
kohya_model_utils,
|
4 |
+
kohya_utils,
|
5 |
+
locon,
|
6 |
+
loha,
|
7 |
+
utils,
|
8 |
+
)
|
lycoris/kohya.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# network module for kohya
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from typing import List
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .kohya_utils import *
|
13 |
+
from .locon import LoConModule
|
14 |
+
from .loha import LohaModule
|
15 |
+
|
16 |
+
|
17 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
18 |
+
if network_dim is None:
|
19 |
+
network_dim = 4 # default
|
20 |
+
conv_dim = int(kwargs.get('conv_dim', network_dim))
|
21 |
+
conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
|
22 |
+
dropout = float(kwargs.get('dropout', 0.))
|
23 |
+
algo = kwargs.get('algo', 'lora')
|
24 |
+
network_module = {
|
25 |
+
'lora': LoConModule,
|
26 |
+
'loha': LohaModule,
|
27 |
+
}[algo]
|
28 |
+
|
29 |
+
print(f'Using rank adaptation algo: {algo}')
|
30 |
+
network = LoRANetwork(
|
31 |
+
text_encoder, unet,
|
32 |
+
multiplier=multiplier,
|
33 |
+
lora_dim=network_dim, conv_lora_dim=conv_dim,
|
34 |
+
alpha=network_alpha, conv_alpha=conv_alpha,
|
35 |
+
dropout=dropout,
|
36 |
+
network_module=network_module
|
37 |
+
)
|
38 |
+
|
39 |
+
return network
|
40 |
+
|
41 |
+
|
42 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
43 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
44 |
+
from safetensors.torch import load_file, safe_open
|
45 |
+
weights_sd = load_file(file)
|
46 |
+
else:
|
47 |
+
weights_sd = torch.load(file, map_location='cpu')
|
48 |
+
|
49 |
+
# get dim (rank)
|
50 |
+
network_alpha = None
|
51 |
+
network_dim = None
|
52 |
+
for key, value in weights_sd.items():
|
53 |
+
if network_alpha is None and 'alpha' in key:
|
54 |
+
network_alpha = value
|
55 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
56 |
+
network_dim = value.size()[0]
|
57 |
+
|
58 |
+
if network_alpha is None:
|
59 |
+
network_alpha = network_dim
|
60 |
+
|
61 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
62 |
+
network.weights_sd = weights_sd
|
63 |
+
return network
|
64 |
+
|
65 |
+
|
66 |
+
class LoRANetwork(torch.nn.Module):
|
67 |
+
'''
|
68 |
+
LoRA + LoCon
|
69 |
+
'''
|
70 |
+
# Ignore proj_in or proj_out, their channels is only a few.
|
71 |
+
UNET_TARGET_REPLACE_MODULE = [
|
72 |
+
"Transformer2DModel",
|
73 |
+
"Attention",
|
74 |
+
"ResnetBlock2D",
|
75 |
+
"Downsample2D",
|
76 |
+
"Upsample2D"
|
77 |
+
]
|
78 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
79 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
80 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
text_encoder, unet,
|
85 |
+
multiplier=1.0,
|
86 |
+
lora_dim=4, conv_lora_dim=4,
|
87 |
+
alpha=1, conv_alpha=1,
|
88 |
+
dropout = 0, network_module = LoConModule,
|
89 |
+
) -> None:
|
90 |
+
super().__init__()
|
91 |
+
self.multiplier = multiplier
|
92 |
+
self.lora_dim = lora_dim
|
93 |
+
self.conv_lora_dim = int(conv_lora_dim)
|
94 |
+
if self.conv_lora_dim != self.lora_dim:
|
95 |
+
print('Apply different lora dim for conv layer')
|
96 |
+
print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
|
97 |
+
|
98 |
+
self.alpha = alpha
|
99 |
+
self.conv_alpha = float(conv_alpha)
|
100 |
+
if self.alpha != self.conv_alpha:
|
101 |
+
print('Apply different alpha value for conv layer')
|
102 |
+
print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
|
103 |
+
|
104 |
+
if 1 >= dropout >= 0:
|
105 |
+
print(f'Use Dropout value: {dropout}')
|
106 |
+
self.dropout = dropout
|
107 |
+
|
108 |
+
# create module instances
|
109 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[network_module]:
|
110 |
+
print('Create LoCon Module')
|
111 |
+
loras = []
|
112 |
+
for name, module in root_module.named_modules():
|
113 |
+
if module.__class__.__name__ in target_replace_modules:
|
114 |
+
for child_name, child_module in module.named_modules():
|
115 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
116 |
+
lora_name = lora_name.replace('.', '_')
|
117 |
+
if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
|
118 |
+
lora = network_module(
|
119 |
+
lora_name, child_module, self.multiplier,
|
120 |
+
self.lora_dim, self.alpha, self.dropout
|
121 |
+
)
|
122 |
+
elif child_module.__class__.__name__ == 'Conv2d':
|
123 |
+
k_size, *_ = child_module.kernel_size
|
124 |
+
if k_size==1 and lora_dim>0:
|
125 |
+
lora = network_module(
|
126 |
+
lora_name, child_module, self.multiplier,
|
127 |
+
self.lora_dim, self.alpha, self.dropout
|
128 |
+
)
|
129 |
+
elif conv_lora_dim>0:
|
130 |
+
lora = network_module(
|
131 |
+
lora_name, child_module, self.multiplier,
|
132 |
+
self.conv_lora_dim, self.conv_alpha, self.dropout
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
continue
|
136 |
+
else:
|
137 |
+
continue
|
138 |
+
loras.append(lora)
|
139 |
+
return loras
|
140 |
+
|
141 |
+
self.text_encoder_loras = create_modules(
|
142 |
+
LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
143 |
+
text_encoder,
|
144 |
+
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
145 |
+
)
|
146 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
147 |
+
|
148 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
149 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
150 |
+
|
151 |
+
self.weights_sd = None
|
152 |
+
|
153 |
+
# assertion
|
154 |
+
names = set()
|
155 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
156 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
157 |
+
names.add(lora.lora_name)
|
158 |
+
|
159 |
+
def set_multiplier(self, multiplier):
|
160 |
+
self.multiplier = multiplier
|
161 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
162 |
+
lora.multiplier = self.multiplier
|
163 |
+
|
164 |
+
def load_weights(self, file):
|
165 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
166 |
+
from safetensors.torch import load_file, safe_open
|
167 |
+
self.weights_sd = load_file(file)
|
168 |
+
else:
|
169 |
+
self.weights_sd = torch.load(file, map_location='cpu')
|
170 |
+
|
171 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
172 |
+
if self.weights_sd:
|
173 |
+
weights_has_text_encoder = weights_has_unet = False
|
174 |
+
for key in self.weights_sd.keys():
|
175 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
176 |
+
weights_has_text_encoder = True
|
177 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
178 |
+
weights_has_unet = True
|
179 |
+
|
180 |
+
if apply_text_encoder is None:
|
181 |
+
apply_text_encoder = weights_has_text_encoder
|
182 |
+
else:
|
183 |
+
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
184 |
+
|
185 |
+
if apply_unet is None:
|
186 |
+
apply_unet = weights_has_unet
|
187 |
+
else:
|
188 |
+
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
189 |
+
else:
|
190 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
191 |
+
|
192 |
+
if apply_text_encoder:
|
193 |
+
print("enable LoRA for text encoder")
|
194 |
+
else:
|
195 |
+
self.text_encoder_loras = []
|
196 |
+
|
197 |
+
if apply_unet:
|
198 |
+
print("enable LoRA for U-Net")
|
199 |
+
else:
|
200 |
+
self.unet_loras = []
|
201 |
+
|
202 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
203 |
+
lora.apply_to()
|
204 |
+
self.add_module(lora.lora_name, lora)
|
205 |
+
|
206 |
+
if self.weights_sd:
|
207 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
208 |
+
info = self.load_state_dict(self.weights_sd, False)
|
209 |
+
print(f"weights are loaded: {info}")
|
210 |
+
|
211 |
+
def enable_gradient_checkpointing(self):
|
212 |
+
# not supported
|
213 |
+
def make_ckpt(module):
|
214 |
+
if isinstance(module, torch.nn.Module):
|
215 |
+
module.grad_ckpt = True
|
216 |
+
self.apply(make_ckpt)
|
217 |
+
pass
|
218 |
+
|
219 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
220 |
+
def enumerate_params(loras):
|
221 |
+
params = []
|
222 |
+
for lora in loras:
|
223 |
+
params.extend(lora.parameters())
|
224 |
+
return params
|
225 |
+
|
226 |
+
self.requires_grad_(True)
|
227 |
+
all_params = []
|
228 |
+
|
229 |
+
if self.text_encoder_loras:
|
230 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
231 |
+
if text_encoder_lr is not None:
|
232 |
+
param_data['lr'] = text_encoder_lr
|
233 |
+
all_params.append(param_data)
|
234 |
+
|
235 |
+
if self.unet_loras:
|
236 |
+
param_data = {'params': enumerate_params(self.unet_loras)}
|
237 |
+
if unet_lr is not None:
|
238 |
+
param_data['lr'] = unet_lr
|
239 |
+
all_params.append(param_data)
|
240 |
+
|
241 |
+
return all_params
|
242 |
+
|
243 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
244 |
+
self.requires_grad_(True)
|
245 |
+
|
246 |
+
def on_epoch_start(self, text_encoder, unet):
|
247 |
+
self.train()
|
248 |
+
|
249 |
+
def get_trainable_params(self):
|
250 |
+
return self.parameters()
|
251 |
+
|
252 |
+
def save_weights(self, file, dtype, metadata):
|
253 |
+
if metadata is not None and len(metadata) == 0:
|
254 |
+
metadata = None
|
255 |
+
|
256 |
+
state_dict = self.state_dict()
|
257 |
+
|
258 |
+
if dtype is not None:
|
259 |
+
for key in list(state_dict.keys()):
|
260 |
+
v = state_dict[key]
|
261 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
262 |
+
state_dict[key] = v
|
263 |
+
|
264 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
265 |
+
from safetensors.torch import save_file
|
266 |
+
|
267 |
+
# Precalculate model hashes to save time on indexing
|
268 |
+
if metadata is None:
|
269 |
+
metadata = {}
|
270 |
+
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
271 |
+
metadata["sshs_model_hash"] = model_hash
|
272 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
273 |
+
|
274 |
+
save_file(state_dict, file, metadata)
|
275 |
+
else:
|
276 |
+
torch.save(state_dict, file)
|
lycoris/kohya_model_utils.py
ADDED
@@ -0,0 +1,1184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
3 |
+
'''
|
4 |
+
# v1: split from train_db_fixed.py.
|
5 |
+
# v2: support safetensors
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
11 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
12 |
+
from safetensors.torch import load_file, save_file
|
13 |
+
|
14 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
15 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
16 |
+
BETA_START = 0.00085
|
17 |
+
BETA_END = 0.0120
|
18 |
+
|
19 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
20 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
21 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
22 |
+
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
23 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
24 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
25 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
26 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
27 |
+
UNET_PARAMS_NUM_HEADS = 8
|
28 |
+
|
29 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
30 |
+
VAE_PARAMS_RESOLUTION = 256
|
31 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
32 |
+
VAE_PARAMS_OUT_CH = 3
|
33 |
+
VAE_PARAMS_CH = 128
|
34 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
35 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
36 |
+
|
37 |
+
# V2
|
38 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
39 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
40 |
+
|
41 |
+
# Diffusersの設定を読み込むための参照モデル
|
42 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
43 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
44 |
+
|
45 |
+
|
46 |
+
# region StableDiffusion->Diffusersの変換コード
|
47 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
48 |
+
|
49 |
+
|
50 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
51 |
+
"""
|
52 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
53 |
+
"""
|
54 |
+
if n_shave_prefix_segments >= 0:
|
55 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
56 |
+
else:
|
57 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
58 |
+
|
59 |
+
|
60 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
61 |
+
"""
|
62 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
63 |
+
"""
|
64 |
+
mapping = []
|
65 |
+
for old_item in old_list:
|
66 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
67 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
68 |
+
|
69 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
70 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
71 |
+
|
72 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
73 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
74 |
+
|
75 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
76 |
+
|
77 |
+
mapping.append({"old": old_item, "new": new_item})
|
78 |
+
|
79 |
+
return mapping
|
80 |
+
|
81 |
+
|
82 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
83 |
+
"""
|
84 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
85 |
+
"""
|
86 |
+
mapping = []
|
87 |
+
for old_item in old_list:
|
88 |
+
new_item = old_item
|
89 |
+
|
90 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
91 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
92 |
+
|
93 |
+
mapping.append({"old": old_item, "new": new_item})
|
94 |
+
|
95 |
+
return mapping
|
96 |
+
|
97 |
+
|
98 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
99 |
+
"""
|
100 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
101 |
+
"""
|
102 |
+
mapping = []
|
103 |
+
for old_item in old_list:
|
104 |
+
new_item = old_item
|
105 |
+
|
106 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
107 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
108 |
+
|
109 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
110 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
111 |
+
|
112 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
113 |
+
|
114 |
+
mapping.append({"old": old_item, "new": new_item})
|
115 |
+
|
116 |
+
return mapping
|
117 |
+
|
118 |
+
|
119 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
120 |
+
"""
|
121 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
122 |
+
"""
|
123 |
+
mapping = []
|
124 |
+
for old_item in old_list:
|
125 |
+
new_item = old_item
|
126 |
+
|
127 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
128 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
131 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
134 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
137 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
140 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
141 |
+
|
142 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
143 |
+
|
144 |
+
mapping.append({"old": old_item, "new": new_item})
|
145 |
+
|
146 |
+
return mapping
|
147 |
+
|
148 |
+
|
149 |
+
def assign_to_checkpoint(
|
150 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
154 |
+
to them. It splits attention layers, and takes into account additional replacements
|
155 |
+
that may arise.
|
156 |
+
|
157 |
+
Assigns the weights to the new checkpoint.
|
158 |
+
"""
|
159 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
160 |
+
|
161 |
+
# Splits the attention layers into three variables.
|
162 |
+
if attention_paths_to_split is not None:
|
163 |
+
for path, path_map in attention_paths_to_split.items():
|
164 |
+
old_tensor = old_checkpoint[path]
|
165 |
+
channels = old_tensor.shape[0] // 3
|
166 |
+
|
167 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
168 |
+
|
169 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
170 |
+
|
171 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
172 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
173 |
+
|
174 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
175 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
176 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
177 |
+
|
178 |
+
for path in paths:
|
179 |
+
new_path = path["new"]
|
180 |
+
|
181 |
+
# These have already been assigned
|
182 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
183 |
+
continue
|
184 |
+
|
185 |
+
# Global renaming happens here
|
186 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
187 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
188 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
189 |
+
|
190 |
+
if additional_replacements is not None:
|
191 |
+
for replacement in additional_replacements:
|
192 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
193 |
+
|
194 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
195 |
+
if "proj_attn.weight" in new_path:
|
196 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
197 |
+
else:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
199 |
+
|
200 |
+
|
201 |
+
def conv_attn_to_linear(checkpoint):
|
202 |
+
keys = list(checkpoint.keys())
|
203 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
204 |
+
for key in keys:
|
205 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
208 |
+
elif "proj_attn.weight" in key:
|
209 |
+
if checkpoint[key].ndim > 2:
|
210 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
211 |
+
|
212 |
+
|
213 |
+
def linear_transformer_to_conv(checkpoint):
|
214 |
+
keys = list(checkpoint.keys())
|
215 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
216 |
+
for key in keys:
|
217 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
218 |
+
if checkpoint[key].ndim == 2:
|
219 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
220 |
+
|
221 |
+
|
222 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
223 |
+
"""
|
224 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
225 |
+
"""
|
226 |
+
|
227 |
+
# extract state_dict for UNet
|
228 |
+
unet_state_dict = {}
|
229 |
+
unet_key = "model.diffusion_model."
|
230 |
+
keys = list(checkpoint.keys())
|
231 |
+
for key in keys:
|
232 |
+
if key.startswith(unet_key):
|
233 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
234 |
+
|
235 |
+
new_checkpoint = {}
|
236 |
+
|
237 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
238 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
239 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
240 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
243 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
244 |
+
|
245 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
246 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
247 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
248 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
249 |
+
|
250 |
+
# Retrieves the keys for the input blocks only
|
251 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
252 |
+
input_blocks = {
|
253 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
254 |
+
for layer_id in range(num_input_blocks)
|
255 |
+
}
|
256 |
+
|
257 |
+
# Retrieves the keys for the middle blocks only
|
258 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
259 |
+
middle_blocks = {
|
260 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
261 |
+
for layer_id in range(num_middle_blocks)
|
262 |
+
}
|
263 |
+
|
264 |
+
# Retrieves the keys for the output blocks only
|
265 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
266 |
+
output_blocks = {
|
267 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
268 |
+
for layer_id in range(num_output_blocks)
|
269 |
+
}
|
270 |
+
|
271 |
+
for i in range(1, num_input_blocks):
|
272 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
273 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
274 |
+
|
275 |
+
resnets = [
|
276 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
277 |
+
]
|
278 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
279 |
+
|
280 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.weight"
|
283 |
+
)
|
284 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
285 |
+
f"input_blocks.{i}.0.op.bias"
|
286 |
+
)
|
287 |
+
|
288 |
+
paths = renew_resnet_paths(resnets)
|
289 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
290 |
+
assign_to_checkpoint(
|
291 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
292 |
+
)
|
293 |
+
|
294 |
+
if len(attentions):
|
295 |
+
paths = renew_attention_paths(attentions)
|
296 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
297 |
+
assign_to_checkpoint(
|
298 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
299 |
+
)
|
300 |
+
|
301 |
+
resnet_0 = middle_blocks[0]
|
302 |
+
attentions = middle_blocks[1]
|
303 |
+
resnet_1 = middle_blocks[2]
|
304 |
+
|
305 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
306 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
309 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
310 |
+
|
311 |
+
attentions_paths = renew_attention_paths(attentions)
|
312 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
313 |
+
assign_to_checkpoint(
|
314 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
315 |
+
)
|
316 |
+
|
317 |
+
for i in range(num_output_blocks):
|
318 |
+
block_id = i // (config["layers_per_block"] + 1)
|
319 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
320 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
321 |
+
output_block_list = {}
|
322 |
+
|
323 |
+
for layer in output_block_layers:
|
324 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
325 |
+
if layer_id in output_block_list:
|
326 |
+
output_block_list[layer_id].append(layer_name)
|
327 |
+
else:
|
328 |
+
output_block_list[layer_id] = [layer_name]
|
329 |
+
|
330 |
+
if len(output_block_list) > 1:
|
331 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
332 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
333 |
+
|
334 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
335 |
+
paths = renew_resnet_paths(resnets)
|
336 |
+
|
337 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
338 |
+
assign_to_checkpoint(
|
339 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
340 |
+
)
|
341 |
+
|
342 |
+
# オリジナル:
|
343 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
344 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
345 |
+
|
346 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
347 |
+
for l in output_block_list.values():
|
348 |
+
l.sort()
|
349 |
+
|
350 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
351 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
354 |
+
]
|
355 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
356 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
357 |
+
]
|
358 |
+
|
359 |
+
# Clear attentions as they have been attributed above.
|
360 |
+
if len(attentions) == 2:
|
361 |
+
attentions = []
|
362 |
+
|
363 |
+
if len(attentions):
|
364 |
+
paths = renew_attention_paths(attentions)
|
365 |
+
meta_path = {
|
366 |
+
"old": f"output_blocks.{i}.1",
|
367 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
368 |
+
}
|
369 |
+
assign_to_checkpoint(
|
370 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
374 |
+
for path in resnet_0_paths:
|
375 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
376 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
377 |
+
|
378 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
379 |
+
|
380 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
381 |
+
if v2:
|
382 |
+
linear_transformer_to_conv(new_checkpoint)
|
383 |
+
|
384 |
+
return new_checkpoint
|
385 |
+
|
386 |
+
|
387 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
388 |
+
# extract state dict for VAE
|
389 |
+
vae_state_dict = {}
|
390 |
+
vae_key = "first_stage_model."
|
391 |
+
keys = list(checkpoint.keys())
|
392 |
+
for key in keys:
|
393 |
+
if key.startswith(vae_key):
|
394 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
395 |
+
# if len(vae_state_dict) == 0:
|
396 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
397 |
+
# vae_state_dict = checkpoint
|
398 |
+
|
399 |
+
new_checkpoint = {}
|
400 |
+
|
401 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
402 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
403 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
404 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
405 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
406 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
407 |
+
|
408 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
409 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
410 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
411 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
412 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
413 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
414 |
+
|
415 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
416 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
417 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
418 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
419 |
+
|
420 |
+
# Retrieves the keys for the encoder down blocks only
|
421 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
422 |
+
down_blocks = {
|
423 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
424 |
+
}
|
425 |
+
|
426 |
+
# Retrieves the keys for the decoder up blocks only
|
427 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
428 |
+
up_blocks = {
|
429 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
430 |
+
}
|
431 |
+
|
432 |
+
for i in range(num_down_blocks):
|
433 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
434 |
+
|
435 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
438 |
+
)
|
439 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
440 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
441 |
+
)
|
442 |
+
|
443 |
+
paths = renew_vae_resnet_paths(resnets)
|
444 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
445 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
446 |
+
|
447 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
448 |
+
num_mid_res_blocks = 2
|
449 |
+
for i in range(1, num_mid_res_blocks + 1):
|
450 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
451 |
+
|
452 |
+
paths = renew_vae_resnet_paths(resnets)
|
453 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
454 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
+
|
456 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
457 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
458 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
459 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
460 |
+
conv_attn_to_linear(new_checkpoint)
|
461 |
+
|
462 |
+
for i in range(num_up_blocks):
|
463 |
+
block_id = num_up_blocks - 1 - i
|
464 |
+
resnets = [
|
465 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
466 |
+
]
|
467 |
+
|
468 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
471 |
+
]
|
472 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
473 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
474 |
+
]
|
475 |
+
|
476 |
+
paths = renew_vae_resnet_paths(resnets)
|
477 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
478 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
479 |
+
|
480 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
481 |
+
num_mid_res_blocks = 2
|
482 |
+
for i in range(1, num_mid_res_blocks + 1):
|
483 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
484 |
+
|
485 |
+
paths = renew_vae_resnet_paths(resnets)
|
486 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
487 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
488 |
+
|
489 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
490 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
491 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
492 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
493 |
+
conv_attn_to_linear(new_checkpoint)
|
494 |
+
return new_checkpoint
|
495 |
+
|
496 |
+
|
497 |
+
def create_unet_diffusers_config(v2):
|
498 |
+
"""
|
499 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
500 |
+
"""
|
501 |
+
# unet_params = original_config.model.params.unet_config.params
|
502 |
+
|
503 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
504 |
+
|
505 |
+
down_block_types = []
|
506 |
+
resolution = 1
|
507 |
+
for i in range(len(block_out_channels)):
|
508 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
509 |
+
down_block_types.append(block_type)
|
510 |
+
if i != len(block_out_channels) - 1:
|
511 |
+
resolution *= 2
|
512 |
+
|
513 |
+
up_block_types = []
|
514 |
+
for i in range(len(block_out_channels)):
|
515 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
516 |
+
up_block_types.append(block_type)
|
517 |
+
resolution //= 2
|
518 |
+
|
519 |
+
config = dict(
|
520 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
521 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
522 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
523 |
+
down_block_types=tuple(down_block_types),
|
524 |
+
up_block_types=tuple(up_block_types),
|
525 |
+
block_out_channels=tuple(block_out_channels),
|
526 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
527 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
528 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
529 |
+
)
|
530 |
+
|
531 |
+
return config
|
532 |
+
|
533 |
+
|
534 |
+
def create_vae_diffusers_config():
|
535 |
+
"""
|
536 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
537 |
+
"""
|
538 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
539 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
540 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
541 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
542 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
543 |
+
|
544 |
+
config = dict(
|
545 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
546 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
547 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
548 |
+
down_block_types=tuple(down_block_types),
|
549 |
+
up_block_types=tuple(up_block_types),
|
550 |
+
block_out_channels=tuple(block_out_channels),
|
551 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
552 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
553 |
+
)
|
554 |
+
return config
|
555 |
+
|
556 |
+
|
557 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
558 |
+
keys = list(checkpoint.keys())
|
559 |
+
text_model_dict = {}
|
560 |
+
for key in keys:
|
561 |
+
if key.startswith("cond_stage_model.transformer"):
|
562 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
563 |
+
return text_model_dict
|
564 |
+
|
565 |
+
|
566 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
567 |
+
# 嫌になるくらい違うぞ!
|
568 |
+
def convert_key(key):
|
569 |
+
if not key.startswith("cond_stage_model"):
|
570 |
+
return None
|
571 |
+
|
572 |
+
# common conversion
|
573 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
574 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
575 |
+
|
576 |
+
if "resblocks" in key:
|
577 |
+
# resblocks conversion
|
578 |
+
key = key.replace(".resblocks.", ".layers.")
|
579 |
+
if ".ln_" in key:
|
580 |
+
key = key.replace(".ln_", ".layer_norm")
|
581 |
+
elif ".mlp." in key:
|
582 |
+
key = key.replace(".c_fc.", ".fc1.")
|
583 |
+
key = key.replace(".c_proj.", ".fc2.")
|
584 |
+
elif '.attn.out_proj' in key:
|
585 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
586 |
+
elif '.attn.in_proj' in key:
|
587 |
+
key = None # 特殊なので後で処理する
|
588 |
+
else:
|
589 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
590 |
+
elif '.positional_embedding' in key:
|
591 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
592 |
+
elif '.text_projection' in key:
|
593 |
+
key = None # 使われない???
|
594 |
+
elif '.logit_scale' in key:
|
595 |
+
key = None # 使われない???
|
596 |
+
elif '.token_embedding' in key:
|
597 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
598 |
+
elif '.ln_final' in key:
|
599 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
600 |
+
return key
|
601 |
+
|
602 |
+
keys = list(checkpoint.keys())
|
603 |
+
new_sd = {}
|
604 |
+
for key in keys:
|
605 |
+
# remove resblocks 23
|
606 |
+
if '.resblocks.23.' in key:
|
607 |
+
continue
|
608 |
+
new_key = convert_key(key)
|
609 |
+
if new_key is None:
|
610 |
+
continue
|
611 |
+
new_sd[new_key] = checkpoint[key]
|
612 |
+
|
613 |
+
# attnの変換
|
614 |
+
for key in keys:
|
615 |
+
if '.resblocks.23.' in key:
|
616 |
+
continue
|
617 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
618 |
+
# 三つに分割
|
619 |
+
values = torch.chunk(checkpoint[key], 3)
|
620 |
+
|
621 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
622 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
623 |
+
key_pfx = key_pfx.replace("_weight", "")
|
624 |
+
key_pfx = key_pfx.replace("_bias", "")
|
625 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
626 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
627 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
628 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
629 |
+
|
630 |
+
# rename or add position_ids
|
631 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
632 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
633 |
+
# waifu diffusion v1.4
|
634 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
635 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
636 |
+
else:
|
637 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
638 |
+
|
639 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
640 |
+
return new_sd
|
641 |
+
|
642 |
+
# endregion
|
643 |
+
|
644 |
+
|
645 |
+
# region Diffusers->StableDiffusion の変換コード
|
646 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
647 |
+
|
648 |
+
def conv_transformer_to_linear(checkpoint):
|
649 |
+
keys = list(checkpoint.keys())
|
650 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
651 |
+
for key in keys:
|
652 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
653 |
+
if checkpoint[key].ndim > 2:
|
654 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
655 |
+
|
656 |
+
|
657 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
658 |
+
unet_conversion_map = [
|
659 |
+
# (stable-diffusion, HF Diffusers)
|
660 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
661 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
662 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
663 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
664 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
665 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
666 |
+
("out.0.weight", "conv_norm_out.weight"),
|
667 |
+
("out.0.bias", "conv_norm_out.bias"),
|
668 |
+
("out.2.weight", "conv_out.weight"),
|
669 |
+
("out.2.bias", "conv_out.bias"),
|
670 |
+
]
|
671 |
+
|
672 |
+
unet_conversion_map_resnet = [
|
673 |
+
# (stable-diffusion, HF Diffusers)
|
674 |
+
("in_layers.0", "norm1"),
|
675 |
+
("in_layers.2", "conv1"),
|
676 |
+
("out_layers.0", "norm2"),
|
677 |
+
("out_layers.3", "conv2"),
|
678 |
+
("emb_layers.1", "time_emb_proj"),
|
679 |
+
("skip_connection", "conv_shortcut"),
|
680 |
+
]
|
681 |
+
|
682 |
+
unet_conversion_map_layer = []
|
683 |
+
for i in range(4):
|
684 |
+
# loop over downblocks/upblocks
|
685 |
+
|
686 |
+
for j in range(2):
|
687 |
+
# loop over resnets/attentions for downblocks
|
688 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
689 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
690 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
691 |
+
|
692 |
+
if i < 3:
|
693 |
+
# no attention layers in down_blocks.3
|
694 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
695 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
696 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
697 |
+
|
698 |
+
for j in range(3):
|
699 |
+
# loop over resnets/attentions for upblocks
|
700 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
701 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
702 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
703 |
+
|
704 |
+
if i > 0:
|
705 |
+
# no attention layers in up_blocks.0
|
706 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
707 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
708 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
709 |
+
|
710 |
+
if i < 3:
|
711 |
+
# no downsample in down_blocks.3
|
712 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
713 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
714 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
715 |
+
|
716 |
+
# no upsample in up_blocks.3
|
717 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
718 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
719 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
720 |
+
|
721 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
722 |
+
sd_mid_atn_prefix = "middle_block.1."
|
723 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
724 |
+
|
725 |
+
for j in range(2):
|
726 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
727 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
728 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
729 |
+
|
730 |
+
# buyer beware: this is a *brittle* function,
|
731 |
+
# and correct output requires that all of these pieces interact in
|
732 |
+
# the exact order in which I have arranged them.
|
733 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
734 |
+
for sd_name, hf_name in unet_conversion_map:
|
735 |
+
mapping[hf_name] = sd_name
|
736 |
+
for k, v in mapping.items():
|
737 |
+
if "resnets" in k:
|
738 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
739 |
+
v = v.replace(hf_part, sd_part)
|
740 |
+
mapping[k] = v
|
741 |
+
for k, v in mapping.items():
|
742 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
743 |
+
v = v.replace(hf_part, sd_part)
|
744 |
+
mapping[k] = v
|
745 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
746 |
+
|
747 |
+
if v2:
|
748 |
+
conv_transformer_to_linear(new_state_dict)
|
749 |
+
|
750 |
+
return new_state_dict
|
751 |
+
|
752 |
+
|
753 |
+
# ================#
|
754 |
+
# VAE Conversion #
|
755 |
+
# ================#
|
756 |
+
|
757 |
+
def reshape_weight_for_sd(w):
|
758 |
+
# convert HF linear weights to SD conv2d weights
|
759 |
+
return w.reshape(*w.shape, 1, 1)
|
760 |
+
|
761 |
+
|
762 |
+
def convert_vae_state_dict(vae_state_dict):
|
763 |
+
vae_conversion_map = [
|
764 |
+
# (stable-diffusion, HF Diffusers)
|
765 |
+
("nin_shortcut", "conv_shortcut"),
|
766 |
+
("norm_out", "conv_norm_out"),
|
767 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
768 |
+
]
|
769 |
+
|
770 |
+
for i in range(4):
|
771 |
+
# down_blocks have two resnets
|
772 |
+
for j in range(2):
|
773 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
774 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
775 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
776 |
+
|
777 |
+
if i < 3:
|
778 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
779 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
780 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
781 |
+
|
782 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
783 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
784 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
785 |
+
|
786 |
+
# up_blocks have three resnets
|
787 |
+
# also, up blocks in hf are numbered in reverse from sd
|
788 |
+
for j in range(3):
|
789 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
790 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
791 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
792 |
+
|
793 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
794 |
+
for i in range(2):
|
795 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
796 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
797 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
798 |
+
|
799 |
+
vae_conversion_map_attn = [
|
800 |
+
# (stable-diffusion, HF Diffusers)
|
801 |
+
("norm.", "group_norm."),
|
802 |
+
("q.", "query."),
|
803 |
+
("k.", "key."),
|
804 |
+
("v.", "value."),
|
805 |
+
("proj_out.", "proj_attn."),
|
806 |
+
]
|
807 |
+
|
808 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
809 |
+
for k, v in mapping.items():
|
810 |
+
for sd_part, hf_part in vae_conversion_map:
|
811 |
+
v = v.replace(hf_part, sd_part)
|
812 |
+
mapping[k] = v
|
813 |
+
for k, v in mapping.items():
|
814 |
+
if "attentions" in k:
|
815 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
816 |
+
v = v.replace(hf_part, sd_part)
|
817 |
+
mapping[k] = v
|
818 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
819 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
820 |
+
for k, v in new_state_dict.items():
|
821 |
+
for weight_name in weights_to_convert:
|
822 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
823 |
+
# print(f"Reshaping {k} for SD format")
|
824 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
825 |
+
|
826 |
+
return new_state_dict
|
827 |
+
|
828 |
+
|
829 |
+
# endregion
|
830 |
+
|
831 |
+
# region 自作のモデル読み書きなど
|
832 |
+
|
833 |
+
def is_safetensors(path):
|
834 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
835 |
+
|
836 |
+
|
837 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
838 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
839 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
840 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
841 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
842 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
843 |
+
]
|
844 |
+
|
845 |
+
if is_safetensors(ckpt_path):
|
846 |
+
checkpoint = None
|
847 |
+
state_dict = load_file(ckpt_path, "cpu")
|
848 |
+
else:
|
849 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
850 |
+
if "state_dict" in checkpoint:
|
851 |
+
state_dict = checkpoint["state_dict"]
|
852 |
+
else:
|
853 |
+
state_dict = checkpoint
|
854 |
+
checkpoint = None
|
855 |
+
|
856 |
+
key_reps = []
|
857 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
858 |
+
for key in state_dict.keys():
|
859 |
+
if key.startswith(rep_from):
|
860 |
+
new_key = rep_to + key[len(rep_from):]
|
861 |
+
key_reps.append((key, new_key))
|
862 |
+
|
863 |
+
for key, new_key in key_reps:
|
864 |
+
state_dict[new_key] = state_dict[key]
|
865 |
+
del state_dict[key]
|
866 |
+
|
867 |
+
return checkpoint, state_dict
|
868 |
+
|
869 |
+
|
870 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
871 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
872 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
873 |
+
if dtype is not None:
|
874 |
+
for k, v in state_dict.items():
|
875 |
+
if type(v) is torch.Tensor:
|
876 |
+
state_dict[k] = v.to(dtype)
|
877 |
+
|
878 |
+
# Convert the UNet2DConditionModel model.
|
879 |
+
unet_config = create_unet_diffusers_config(v2)
|
880 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
881 |
+
|
882 |
+
unet = UNet2DConditionModel(**unet_config)
|
883 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
884 |
+
print("loading u-net:", info)
|
885 |
+
|
886 |
+
# Convert the VAE model.
|
887 |
+
vae_config = create_vae_diffusers_config()
|
888 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
889 |
+
|
890 |
+
vae = AutoencoderKL(**vae_config)
|
891 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
892 |
+
print("loading vae:", info)
|
893 |
+
|
894 |
+
# convert text_model
|
895 |
+
if v2:
|
896 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
897 |
+
cfg = CLIPTextConfig(
|
898 |
+
vocab_size=49408,
|
899 |
+
hidden_size=1024,
|
900 |
+
intermediate_size=4096,
|
901 |
+
num_hidden_layers=23,
|
902 |
+
num_attention_heads=16,
|
903 |
+
max_position_embeddings=77,
|
904 |
+
hidden_act="gelu",
|
905 |
+
layer_norm_eps=1e-05,
|
906 |
+
dropout=0.0,
|
907 |
+
attention_dropout=0.0,
|
908 |
+
initializer_range=0.02,
|
909 |
+
initializer_factor=1.0,
|
910 |
+
pad_token_id=1,
|
911 |
+
bos_token_id=0,
|
912 |
+
eos_token_id=2,
|
913 |
+
model_type="clip_text_model",
|
914 |
+
projection_dim=512,
|
915 |
+
torch_dtype="float32",
|
916 |
+
transformers_version="4.25.0.dev0",
|
917 |
+
)
|
918 |
+
text_model = CLIPTextModel._from_config(cfg)
|
919 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
920 |
+
else:
|
921 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
922 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
923 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
924 |
+
print("loading text encoder:", info)
|
925 |
+
|
926 |
+
return text_model, vae, unet
|
927 |
+
|
928 |
+
|
929 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
930 |
+
def convert_key(key):
|
931 |
+
# position_idsの除去
|
932 |
+
if ".position_ids" in key:
|
933 |
+
return None
|
934 |
+
|
935 |
+
# common
|
936 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
937 |
+
key = key.replace("text_model.", "")
|
938 |
+
if "layers" in key:
|
939 |
+
# resblocks conversion
|
940 |
+
key = key.replace(".layers.", ".resblocks.")
|
941 |
+
if ".layer_norm" in key:
|
942 |
+
key = key.replace(".layer_norm", ".ln_")
|
943 |
+
elif ".mlp." in key:
|
944 |
+
key = key.replace(".fc1.", ".c_fc.")
|
945 |
+
key = key.replace(".fc2.", ".c_proj.")
|
946 |
+
elif '.self_attn.out_proj' in key:
|
947 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
948 |
+
elif '.self_attn.' in key:
|
949 |
+
key = None # 特殊なので後で処理する
|
950 |
+
else:
|
951 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
952 |
+
elif '.position_embedding' in key:
|
953 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
954 |
+
elif '.token_embedding' in key:
|
955 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
956 |
+
elif 'final_layer_norm' in key:
|
957 |
+
key = key.replace("final_layer_norm", "ln_final")
|
958 |
+
return key
|
959 |
+
|
960 |
+
keys = list(checkpoint.keys())
|
961 |
+
new_sd = {}
|
962 |
+
for key in keys:
|
963 |
+
new_key = convert_key(key)
|
964 |
+
if new_key is None:
|
965 |
+
continue
|
966 |
+
new_sd[new_key] = checkpoint[key]
|
967 |
+
|
968 |
+
# attnの変換
|
969 |
+
for key in keys:
|
970 |
+
if 'layers' in key and 'q_proj' in key:
|
971 |
+
# 三つを結合
|
972 |
+
key_q = key
|
973 |
+
key_k = key.replace("q_proj", "k_proj")
|
974 |
+
key_v = key.replace("q_proj", "v_proj")
|
975 |
+
|
976 |
+
value_q = checkpoint[key_q]
|
977 |
+
value_k = checkpoint[key_k]
|
978 |
+
value_v = checkpoint[key_v]
|
979 |
+
value = torch.cat([value_q, value_k, value_v])
|
980 |
+
|
981 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
982 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
983 |
+
new_sd[new_key] = value
|
984 |
+
|
985 |
+
# 最後の層などを捏造するか
|
986 |
+
if make_dummy_weights:
|
987 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
988 |
+
keys = list(new_sd.keys())
|
989 |
+
for key in keys:
|
990 |
+
if key.startswith("transformer.resblocks.22."):
|
991 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
992 |
+
|
993 |
+
# Diffusersに含まれない重みを作っておく
|
994 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
995 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
996 |
+
|
997 |
+
return new_sd
|
998 |
+
|
999 |
+
|
1000 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
1001 |
+
if ckpt_path is not None:
|
1002 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1003 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1004 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1005 |
+
checkpoint = {}
|
1006 |
+
strict = False
|
1007 |
+
else:
|
1008 |
+
strict = True
|
1009 |
+
if "state_dict" in state_dict:
|
1010 |
+
del state_dict["state_dict"]
|
1011 |
+
else:
|
1012 |
+
# 新しく作る
|
1013 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1014 |
+
checkpoint = {}
|
1015 |
+
state_dict = {}
|
1016 |
+
strict = False
|
1017 |
+
|
1018 |
+
def update_sd(prefix, sd):
|
1019 |
+
for k, v in sd.items():
|
1020 |
+
key = prefix + k
|
1021 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1022 |
+
if save_dtype is not None:
|
1023 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1024 |
+
state_dict[key] = v
|
1025 |
+
|
1026 |
+
# Convert the UNet model
|
1027 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1028 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1029 |
+
|
1030 |
+
# Convert the text encoder model
|
1031 |
+
if v2:
|
1032 |
+
make_dummy = ckpt_path is None # 参照元のcheckpoint���ない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1033 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1034 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1035 |
+
else:
|
1036 |
+
text_enc_dict = text_encoder.state_dict()
|
1037 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1038 |
+
|
1039 |
+
# Convert the VAE
|
1040 |
+
if vae is not None:
|
1041 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1042 |
+
update_sd("first_stage_model.", vae_dict)
|
1043 |
+
|
1044 |
+
# Put together new checkpoint
|
1045 |
+
key_count = len(state_dict.keys())
|
1046 |
+
new_ckpt = {'state_dict': state_dict}
|
1047 |
+
|
1048 |
+
if 'epoch' in checkpoint:
|
1049 |
+
epochs += checkpoint['epoch']
|
1050 |
+
if 'global_step' in checkpoint:
|
1051 |
+
steps += checkpoint['global_step']
|
1052 |
+
|
1053 |
+
new_ckpt['epoch'] = epochs
|
1054 |
+
new_ckpt['global_step'] = steps
|
1055 |
+
|
1056 |
+
if is_safetensors(output_file):
|
1057 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1058 |
+
save_file(state_dict, output_file)
|
1059 |
+
else:
|
1060 |
+
torch.save(new_ckpt, output_file)
|
1061 |
+
|
1062 |
+
return key_count
|
1063 |
+
|
1064 |
+
|
1065 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1066 |
+
if pretrained_model_name_or_path is None:
|
1067 |
+
# load default settings for v1/v2
|
1068 |
+
if v2:
|
1069 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1070 |
+
else:
|
1071 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1072 |
+
|
1073 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1074 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1075 |
+
if vae is None:
|
1076 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1077 |
+
|
1078 |
+
pipeline = StableDiffusionPipeline(
|
1079 |
+
unet=unet,
|
1080 |
+
text_encoder=text_encoder,
|
1081 |
+
vae=vae,
|
1082 |
+
scheduler=scheduler,
|
1083 |
+
tokenizer=tokenizer,
|
1084 |
+
safety_checker=None,
|
1085 |
+
feature_extractor=None,
|
1086 |
+
requires_safety_checker=None,
|
1087 |
+
)
|
1088 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1089 |
+
|
1090 |
+
|
1091 |
+
VAE_PREFIX = "first_stage_model."
|
1092 |
+
|
1093 |
+
|
1094 |
+
def load_vae(vae_id, dtype):
|
1095 |
+
print(f"load VAE: {vae_id}")
|
1096 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1097 |
+
# Diffusers local/remote
|
1098 |
+
try:
|
1099 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1100 |
+
except EnvironmentError as e:
|
1101 |
+
print(f"exception occurs in loading vae: {e}")
|
1102 |
+
print("retry with subfolder='vae'")
|
1103 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1104 |
+
return vae
|
1105 |
+
|
1106 |
+
# local
|
1107 |
+
vae_config = create_vae_diffusers_config()
|
1108 |
+
|
1109 |
+
if vae_id.endswith(".bin"):
|
1110 |
+
# SD 1.5 VAE on Huggingface
|
1111 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1112 |
+
else:
|
1113 |
+
# StableDiffusion
|
1114 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1115 |
+
else torch.load(vae_id, map_location="cpu"))
|
1116 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1117 |
+
|
1118 |
+
# vae only or full model
|
1119 |
+
full_model = False
|
1120 |
+
for vae_key in vae_sd:
|
1121 |
+
if vae_key.startswith(VAE_PREFIX):
|
1122 |
+
full_model = True
|
1123 |
+
break
|
1124 |
+
if not full_model:
|
1125 |
+
sd = {}
|
1126 |
+
for key, value in vae_sd.items():
|
1127 |
+
sd[VAE_PREFIX + key] = value
|
1128 |
+
vae_sd = sd
|
1129 |
+
del sd
|
1130 |
+
|
1131 |
+
# Convert the VAE model.
|
1132 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1133 |
+
|
1134 |
+
vae = AutoencoderKL(**vae_config)
|
1135 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1136 |
+
return vae
|
1137 |
+
|
1138 |
+
# endregion
|
1139 |
+
|
1140 |
+
|
1141 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1142 |
+
max_width, max_height = max_reso
|
1143 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1144 |
+
|
1145 |
+
resos = set()
|
1146 |
+
|
1147 |
+
size = int(math.sqrt(max_area)) * divisible
|
1148 |
+
resos.add((size, size))
|
1149 |
+
|
1150 |
+
size = min_size
|
1151 |
+
while size <= max_size:
|
1152 |
+
width = size
|
1153 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1154 |
+
resos.add((width, height))
|
1155 |
+
resos.add((height, width))
|
1156 |
+
|
1157 |
+
# # make additional resos
|
1158 |
+
# if width >= height and width - divisible >= min_size:
|
1159 |
+
# resos.add((width - divisible, height))
|
1160 |
+
# resos.add((height, width - divisible))
|
1161 |
+
# if height >= width and height - divisible >= min_size:
|
1162 |
+
# resos.add((width, height - divisible))
|
1163 |
+
# resos.add((height - divisible, width))
|
1164 |
+
|
1165 |
+
size += divisible
|
1166 |
+
|
1167 |
+
resos = list(resos)
|
1168 |
+
resos.sort()
|
1169 |
+
|
1170 |
+
aspect_ratios = [w / h for w, h in resos]
|
1171 |
+
return resos, aspect_ratios
|
1172 |
+
|
1173 |
+
|
1174 |
+
if __name__ == '__main__':
|
1175 |
+
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
1176 |
+
print(len(resos))
|
1177 |
+
print(resos)
|
1178 |
+
print(aspect_ratios)
|
1179 |
+
|
1180 |
+
ars = set()
|
1181 |
+
for ar in aspect_ratios:
|
1182 |
+
if ar in ars:
|
1183 |
+
print("error! duplicate ar:", ar)
|
1184 |
+
ars.add(ar)
|
lycoris/kohya_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
|
2 |
+
|
3 |
+
import hashlib
|
4 |
+
import safetensors
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
|
8 |
+
def addnet_hash_legacy(b):
|
9 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
10 |
+
m = hashlib.sha256()
|
11 |
+
|
12 |
+
b.seek(0x100000)
|
13 |
+
m.update(b.read(0x10000))
|
14 |
+
return m.hexdigest()[0:8]
|
15 |
+
|
16 |
+
|
17 |
+
def addnet_hash_safetensors(b):
|
18 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
19 |
+
hash_sha256 = hashlib.sha256()
|
20 |
+
blksize = 1024 * 1024
|
21 |
+
|
22 |
+
b.seek(0)
|
23 |
+
header = b.read(8)
|
24 |
+
n = int.from_bytes(header, "little")
|
25 |
+
|
26 |
+
offset = n + 8
|
27 |
+
b.seek(offset)
|
28 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
29 |
+
hash_sha256.update(chunk)
|
30 |
+
|
31 |
+
return hash_sha256.hexdigest()
|
32 |
+
|
33 |
+
|
34 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
35 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
36 |
+
save time on indexing the model later."""
|
37 |
+
|
38 |
+
# Because writing user metadata to the file can change the result of
|
39 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
40 |
+
# calculating the hash, as they are meant to be immutable
|
41 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
42 |
+
|
43 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
44 |
+
b = BytesIO(bytes)
|
45 |
+
|
46 |
+
model_hash = addnet_hash_safetensors(b)
|
47 |
+
legacy_hash = addnet_hash_legacy(b)
|
48 |
+
return model_hash, legacy_hash
|
lycoris/locon.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class LoConModule(nn.Module):
|
9 |
+
"""
|
10 |
+
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
|
14 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
15 |
+
super().__init__()
|
16 |
+
self.lora_name = lora_name
|
17 |
+
self.lora_dim = lora_dim
|
18 |
+
|
19 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
20 |
+
# For general LoCon
|
21 |
+
in_dim = org_module.in_channels
|
22 |
+
k_size = org_module.kernel_size
|
23 |
+
stride = org_module.stride
|
24 |
+
padding = org_module.padding
|
25 |
+
out_dim = org_module.out_channels
|
26 |
+
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
27 |
+
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
28 |
+
self.op = F.conv2d
|
29 |
+
self.extra_args = {
|
30 |
+
'stride': stride,
|
31 |
+
'padding': padding
|
32 |
+
}
|
33 |
+
else:
|
34 |
+
in_dim = org_module.in_features
|
35 |
+
out_dim = org_module.out_features
|
36 |
+
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
37 |
+
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
38 |
+
self.op = F.linear
|
39 |
+
self.extra_args = {}
|
40 |
+
self.shape = org_module.weight.shape
|
41 |
+
|
42 |
+
if dropout:
|
43 |
+
self.dropout = nn.Dropout(dropout)
|
44 |
+
else:
|
45 |
+
self.dropout = nn.Identity()
|
46 |
+
|
47 |
+
if type(alpha) == torch.Tensor:
|
48 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
49 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
50 |
+
self.scale = alpha / self.lora_dim
|
51 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
52 |
+
|
53 |
+
# same as microsoft's
|
54 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
55 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
56 |
+
|
57 |
+
self.multiplier = multiplier
|
58 |
+
self.org_module = [org_module]
|
59 |
+
|
60 |
+
def apply_to(self):
|
61 |
+
self.org_module[0].forward = self.forward
|
62 |
+
|
63 |
+
def make_weight(self):
|
64 |
+
wa = self.lora_up.weight
|
65 |
+
wb = self.lora_down.weight
|
66 |
+
return (wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)).view(self.shape)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
70 |
+
return self.op(
|
71 |
+
x,
|
72 |
+
(self.org_module[0].weight.data
|
73 |
+
+ self.dropout(self.make_weight()) * self.multiplier * self.scale),
|
74 |
+
bias,
|
75 |
+
**self.extra_args,
|
76 |
+
)
|
lycoris/loha.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class HadaWeight(torch.autograd.Function):
|
9 |
+
@staticmethod
|
10 |
+
def forward(ctx, orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1), dropout=nn.Identity()):
|
11 |
+
ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
|
12 |
+
diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
|
13 |
+
return orig_weight.reshape(diff_weight.shape) + dropout(diff_weight)
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def backward(ctx, grad_out):
|
17 |
+
(w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
|
18 |
+
temp = grad_out*(w2a@w2b)*scale
|
19 |
+
grad_w1a = temp @ w1b.T
|
20 |
+
grad_w1b = w1a.T @ temp
|
21 |
+
|
22 |
+
temp = grad_out * (w1a@w1b)*scale
|
23 |
+
grad_w2a = temp @ w2b.T
|
24 |
+
grad_w2b = w2a.T @ temp
|
25 |
+
|
26 |
+
del temp
|
27 |
+
return grad_out, grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
|
28 |
+
|
29 |
+
|
30 |
+
def make_weight(orig_weight, w1a, w1b, w2a, w2b, scale):
|
31 |
+
return HadaWeight.apply(orig_weight, w1a, w1b, w2a, w2b, scale)
|
32 |
+
|
33 |
+
|
34 |
+
class LohaModule(nn.Module):
|
35 |
+
"""
|
36 |
+
Hadamard product Implementaion for Low Rank Adaptation
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
|
40 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
41 |
+
super().__init__()
|
42 |
+
self.lora_name = lora_name
|
43 |
+
self.lora_dim = lora_dim
|
44 |
+
|
45 |
+
self.shape = org_module.weight.shape
|
46 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
47 |
+
in_dim = org_module.in_channels
|
48 |
+
k_size = org_module.kernel_size
|
49 |
+
out_dim = org_module.out_channels
|
50 |
+
shape = (out_dim, in_dim*k_size[0]*k_size[1])
|
51 |
+
self.op = F.conv2d
|
52 |
+
self.extra_args = {
|
53 |
+
"stride": org_module.stride,
|
54 |
+
"padding": org_module.padding,
|
55 |
+
"dilation": org_module.dilation,
|
56 |
+
"groups": org_module.groups
|
57 |
+
}
|
58 |
+
else:
|
59 |
+
in_dim = org_module.in_features
|
60 |
+
out_dim = org_module.out_features
|
61 |
+
shape = (out_dim, in_dim)
|
62 |
+
self.op = F.linear
|
63 |
+
self.extra_args = {}
|
64 |
+
|
65 |
+
self.hada_w1_a = nn.Parameter(torch.empty(shape[0], lora_dim))
|
66 |
+
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
|
67 |
+
|
68 |
+
self.hada_w2_a = nn.Parameter(torch.empty(shape[0], lora_dim))
|
69 |
+
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
|
70 |
+
|
71 |
+
if dropout:
|
72 |
+
self.dropout = nn.Dropout(dropout)
|
73 |
+
else:
|
74 |
+
self.dropout = nn.Identity()
|
75 |
+
|
76 |
+
if type(alpha) == torch.Tensor:
|
77 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
78 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
79 |
+
self.scale = alpha / self.lora_dim
|
80 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
81 |
+
|
82 |
+
# Need more experiences on init method
|
83 |
+
torch.nn.init.normal_(self.hada_w1_b, std=1)
|
84 |
+
torch.nn.init.normal_(self.hada_w2_b, std=0.05)
|
85 |
+
torch.nn.init.normal_(self.hada_w1_a, std=1)
|
86 |
+
torch.nn.init.constant_(self.hada_w2_a, 0)
|
87 |
+
|
88 |
+
self.multiplier = multiplier
|
89 |
+
self.org_module = [org_module] # remove in applying
|
90 |
+
self.grad_ckpt = False
|
91 |
+
|
92 |
+
def apply_to(self):
|
93 |
+
self.org_module[0].forward = self.forward
|
94 |
+
|
95 |
+
def get_weight(self):
|
96 |
+
d_weight = self.hada_w1_a @ self.hada_w1_b
|
97 |
+
d_weight *= self.hada_w2_a @ self.hada_w2_b
|
98 |
+
return (d_weight).reshape(self.shape)
|
99 |
+
|
100 |
+
@torch.enable_grad()
|
101 |
+
def forward(self, x):
|
102 |
+
# print(torch.mean(torch.abs(self.orig_w1a.to(x.device) - self.hada_w1_a)), end='\r')
|
103 |
+
weight = make_weight(
|
104 |
+
self.org_module[0].weight.data,
|
105 |
+
self.hada_w1_a, self.hada_w1_b,
|
106 |
+
self.hada_w2_a, self.hada_w2_b,
|
107 |
+
scale = torch.tensor(self.scale*self.multiplier),
|
108 |
+
)
|
109 |
+
|
110 |
+
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
111 |
+
return self.op(
|
112 |
+
x,
|
113 |
+
weight.view(self.shape),
|
114 |
+
bias,
|
115 |
+
**self.extra_args
|
116 |
+
)
|
lycoris/utils.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import torch.linalg as linalg
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def extract_conv(
|
13 |
+
weight: Union[torch.Tensor, nn.Parameter],
|
14 |
+
mode = 'fixed',
|
15 |
+
mode_param = 0,
|
16 |
+
device = 'cpu',
|
17 |
+
) -> Tuple[nn.Parameter, nn.Parameter]:
|
18 |
+
out_ch, in_ch, kernel_size, _ = weight.shape
|
19 |
+
|
20 |
+
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1).to(device))
|
21 |
+
|
22 |
+
if mode=='fixed':
|
23 |
+
lora_rank = mode_param
|
24 |
+
elif mode=='threshold':
|
25 |
+
assert mode_param>=0
|
26 |
+
lora_rank = torch.sum(S>mode_param)
|
27 |
+
elif mode=='ratio':
|
28 |
+
assert 1>=mode_param>=0
|
29 |
+
min_s = torch.max(S)*mode_param
|
30 |
+
lora_rank = torch.sum(S>min_s)
|
31 |
+
elif mode=='percentile':
|
32 |
+
assert 1>=mode_param>=0
|
33 |
+
s_cum = torch.cumsum(S, dim=0)
|
34 |
+
min_cum_sum = mode_param * torch.sum(S)
|
35 |
+
lora_rank = torch.sum(s_cum<min_cum_sum)
|
36 |
+
lora_rank = max(1, lora_rank)
|
37 |
+
lora_rank = min(out_ch, in_ch, lora_rank)
|
38 |
+
|
39 |
+
U = U[:, :lora_rank]
|
40 |
+
S = S[:lora_rank]
|
41 |
+
U = U @ torch.diag(S)
|
42 |
+
Vh = Vh[:lora_rank, :]
|
43 |
+
|
44 |
+
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
|
45 |
+
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
|
46 |
+
del U, S, Vh, weight
|
47 |
+
return extract_weight_A, extract_weight_B
|
48 |
+
|
49 |
+
|
50 |
+
def merge_conv(
|
51 |
+
weight_a: Union[torch.Tensor, nn.Parameter],
|
52 |
+
weight_b: Union[torch.Tensor, nn.Parameter],
|
53 |
+
device = 'cpu'
|
54 |
+
):
|
55 |
+
rank, in_ch, kernel_size, k_ = weight_a.shape
|
56 |
+
out_ch, rank_, _, _ = weight_b.shape
|
57 |
+
assert rank == rank_ and kernel_size == k_
|
58 |
+
|
59 |
+
wa = weight_a.to(device)
|
60 |
+
wb = weight_b.to(device)
|
61 |
+
|
62 |
+
if device == 'cpu':
|
63 |
+
wa = wa.float()
|
64 |
+
wb = wb.float()
|
65 |
+
|
66 |
+
merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
|
67 |
+
weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
|
68 |
+
del wb, wa
|
69 |
+
return weight
|
70 |
+
|
71 |
+
|
72 |
+
def extract_linear(
|
73 |
+
weight: Union[torch.Tensor, nn.Parameter],
|
74 |
+
mode = 'fixed',
|
75 |
+
mode_param = 0,
|
76 |
+
device = 'cpu',
|
77 |
+
) -> Tuple[nn.Parameter, nn.Parameter]:
|
78 |
+
out_ch, in_ch = weight.shape
|
79 |
+
|
80 |
+
U, S, Vh = linalg.svd(weight.to(device))
|
81 |
+
|
82 |
+
if mode=='fixed':
|
83 |
+
lora_rank = mode_param
|
84 |
+
elif mode=='threshold':
|
85 |
+
assert mode_param>=0
|
86 |
+
lora_rank = torch.sum(S>mode_param)
|
87 |
+
elif mode=='ratio':
|
88 |
+
assert 1>=mode_param>=0
|
89 |
+
min_s = torch.max(S)*mode_param
|
90 |
+
lora_rank = torch.sum(S>min_s)
|
91 |
+
elif mode=='percentile':
|
92 |
+
assert 1>=mode_param>=0
|
93 |
+
s_cum = torch.cumsum(S, dim=0)
|
94 |
+
min_cum_sum = mode_param * torch.sum(S)
|
95 |
+
lora_rank = torch.sum(s_cum<min_cum_sum)
|
96 |
+
lora_rank = max(1, lora_rank)
|
97 |
+
lora_rank = min(out_ch, in_ch, lora_rank)
|
98 |
+
|
99 |
+
U = U[:, :lora_rank]
|
100 |
+
S = S[:lora_rank]
|
101 |
+
U = U @ torch.diag(S)
|
102 |
+
Vh = Vh[:lora_rank, :]
|
103 |
+
|
104 |
+
extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
|
105 |
+
extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
|
106 |
+
del U, S, Vh, weight
|
107 |
+
return extract_weight_A, extract_weight_B
|
108 |
+
|
109 |
+
|
110 |
+
def merge_linear(
|
111 |
+
weight_a: Union[torch.Tensor, nn.Parameter],
|
112 |
+
weight_b: Union[torch.Tensor, nn.Parameter],
|
113 |
+
device = 'cpu'
|
114 |
+
):
|
115 |
+
rank, in_ch = weight_a.shape
|
116 |
+
out_ch, rank_ = weight_b.shape
|
117 |
+
assert rank == rank_
|
118 |
+
|
119 |
+
wa = weight_a.to(device)
|
120 |
+
wb = weight_b.to(device)
|
121 |
+
|
122 |
+
if device == 'cpu':
|
123 |
+
wa = wa.float()
|
124 |
+
wb = wb.float()
|
125 |
+
|
126 |
+
weight = wb @ wa
|
127 |
+
del wb, wa
|
128 |
+
return weight
|
129 |
+
|
130 |
+
|
131 |
+
def extract_diff(
|
132 |
+
base_model,
|
133 |
+
db_model,
|
134 |
+
mode = 'fixed',
|
135 |
+
linear_mode_param = 0,
|
136 |
+
conv_mode_param = 0,
|
137 |
+
extract_device = 'cpu'
|
138 |
+
):
|
139 |
+
UNET_TARGET_REPLACE_MODULE = [
|
140 |
+
"Transformer2DModel",
|
141 |
+
"Attention",
|
142 |
+
"ResnetBlock2D",
|
143 |
+
"Downsample2D",
|
144 |
+
"Upsample2D"
|
145 |
+
]
|
146 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
147 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
148 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
149 |
+
def make_state_dict(
|
150 |
+
prefix,
|
151 |
+
root_module: torch.nn.Module,
|
152 |
+
target_module: torch.nn.Module,
|
153 |
+
target_replace_modules
|
154 |
+
):
|
155 |
+
loras = {}
|
156 |
+
temp = {}
|
157 |
+
|
158 |
+
for name, module in root_module.named_modules():
|
159 |
+
if module.__class__.__name__ in target_replace_modules:
|
160 |
+
temp[name] = {}
|
161 |
+
for child_name, child_module in module.named_modules():
|
162 |
+
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
163 |
+
continue
|
164 |
+
temp[name][child_name] = child_module.weight
|
165 |
+
|
166 |
+
for name, module in tqdm(list(target_module.named_modules())):
|
167 |
+
if name in temp:
|
168 |
+
weights = temp[name]
|
169 |
+
for child_name, child_module in module.named_modules():
|
170 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
171 |
+
lora_name = lora_name.replace('.', '_')
|
172 |
+
|
173 |
+
layer = child_module.__class__.__name__
|
174 |
+
if layer == 'Linear':
|
175 |
+
extract_a, extract_b = extract_linear(
|
176 |
+
(child_module.weight - weights[child_name]),
|
177 |
+
mode,
|
178 |
+
linear_mode_param,
|
179 |
+
device = extract_device,
|
180 |
+
)
|
181 |
+
elif layer == 'Conv2d':
|
182 |
+
is_linear = (child_module.weight.shape[2] == 1
|
183 |
+
and child_module.weight.shape[3] == 1)
|
184 |
+
extract_a, extract_b = extract_conv(
|
185 |
+
(child_module.weight - weights[child_name]),
|
186 |
+
mode,
|
187 |
+
linear_mode_param if is_linear else conv_mode_param,
|
188 |
+
device = extract_device,
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
continue
|
192 |
+
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
193 |
+
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
194 |
+
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
195 |
+
del extract_a, extract_b
|
196 |
+
return loras
|
197 |
+
|
198 |
+
text_encoder_loras = make_state_dict(
|
199 |
+
LORA_PREFIX_TEXT_ENCODER,
|
200 |
+
base_model[0], db_model[0],
|
201 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
202 |
+
)
|
203 |
+
|
204 |
+
unet_loras = make_state_dict(
|
205 |
+
LORA_PREFIX_UNET,
|
206 |
+
base_model[2], db_model[2],
|
207 |
+
UNET_TARGET_REPLACE_MODULE
|
208 |
+
)
|
209 |
+
print(len(text_encoder_loras), len(unet_loras))
|
210 |
+
return text_encoder_loras|unet_loras
|
211 |
+
|
212 |
+
|
213 |
+
def merge_locon(
|
214 |
+
base_model,
|
215 |
+
locon_state_dict: Dict[str, torch.TensorType],
|
216 |
+
scale: float = 1.0,
|
217 |
+
device = 'cpu'
|
218 |
+
):
|
219 |
+
UNET_TARGET_REPLACE_MODULE = [
|
220 |
+
"Transformer2DModel",
|
221 |
+
"Attention",
|
222 |
+
"ResnetBlock2D",
|
223 |
+
"Downsample2D",
|
224 |
+
"Upsample2D"
|
225 |
+
]
|
226 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
227 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
228 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
229 |
+
def merge(
|
230 |
+
prefix,
|
231 |
+
root_module: torch.nn.Module,
|
232 |
+
target_replace_modules
|
233 |
+
):
|
234 |
+
temp = {}
|
235 |
+
|
236 |
+
for name, module in tqdm(list(root_module.named_modules())):
|
237 |
+
if module.__class__.__name__ in target_replace_modules:
|
238 |
+
temp[name] = {}
|
239 |
+
for child_name, child_module in module.named_modules():
|
240 |
+
layer = child_module.__class__.__name__
|
241 |
+
if layer not in {'Linear', 'Conv2d'}:
|
242 |
+
continue
|
243 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
244 |
+
lora_name = lora_name.replace('.', '_')
|
245 |
+
|
246 |
+
down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
|
247 |
+
up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
|
248 |
+
alpha = locon_state_dict[f'{lora_name}.alpha'].float()
|
249 |
+
rank = down.shape[0]
|
250 |
+
|
251 |
+
if layer == 'Conv2d':
|
252 |
+
delta = merge_conv(down, up, device)
|
253 |
+
child_module.weight.requires_grad_(False)
|
254 |
+
child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
|
255 |
+
del delta
|
256 |
+
elif layer == 'Linear':
|
257 |
+
delta = merge_linear(down, up, device)
|
258 |
+
child_module.weight.requires_grad_(False)
|
259 |
+
child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
|
260 |
+
del delta
|
261 |
+
|
262 |
+
merge(
|
263 |
+
LORA_PREFIX_TEXT_ENCODER,
|
264 |
+
base_model[0],
|
265 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
266 |
+
)
|
267 |
+
merge(
|
268 |
+
LORA_PREFIX_UNET,
|
269 |
+
base_model[2],
|
270 |
+
UNET_TARGET_REPLACE_MODULE
|
271 |
+
)
|
networks/__pycache__/lora.cpython-310.pyc
ADDED
Binary file (7.36 kB). View file
|
|
networks/check_lora_weights.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from safetensors.torch import load_file
|
5 |
+
|
6 |
+
|
7 |
+
def main(file):
|
8 |
+
print(f"loading: {file}")
|
9 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
10 |
+
sd = load_file(file)
|
11 |
+
else:
|
12 |
+
sd = torch.load(file, map_location='cpu')
|
13 |
+
|
14 |
+
values = []
|
15 |
+
|
16 |
+
keys = list(sd.keys())
|
17 |
+
for key in keys:
|
18 |
+
if 'lora_up' in key or 'lora_down' in key:
|
19 |
+
values.append((key, sd[key]))
|
20 |
+
print(f"number of LoRA modules: {len(values)}")
|
21 |
+
|
22 |
+
for key, value in values:
|
23 |
+
value = value.to(torch.float32)
|
24 |
+
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == '__main__':
|
28 |
+
parser = argparse.ArgumentParser()
|
29 |
+
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
|
30 |
+
args = parser.parse_args()
|
31 |
+
|
32 |
+
main(args.file)
|
networks/extract_lora_from_models.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# extract approximating LoRA by svd from two SD models
|
2 |
+
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo!
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from safetensors.torch import load_file, save_file
|
9 |
+
from tqdm import tqdm
|
10 |
+
import library.model_util as model_util
|
11 |
+
import lora
|
12 |
+
|
13 |
+
|
14 |
+
CLAMP_QUANTILE = 0.99
|
15 |
+
MIN_DIFF = 1e-6
|
16 |
+
|
17 |
+
|
18 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
19 |
+
if dtype is not None:
|
20 |
+
for key in list(state_dict.keys()):
|
21 |
+
if type(state_dict[key]) == torch.Tensor:
|
22 |
+
state_dict[key] = state_dict[key].to(dtype)
|
23 |
+
|
24 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
25 |
+
save_file(model, file_name)
|
26 |
+
else:
|
27 |
+
torch.save(model, file_name)
|
28 |
+
|
29 |
+
|
30 |
+
def svd(args):
|
31 |
+
def str_to_dtype(p):
|
32 |
+
if p == 'float':
|
33 |
+
return torch.float
|
34 |
+
if p == 'fp16':
|
35 |
+
return torch.float16
|
36 |
+
if p == 'bf16':
|
37 |
+
return torch.bfloat16
|
38 |
+
return None
|
39 |
+
|
40 |
+
save_dtype = str_to_dtype(args.save_precision)
|
41 |
+
|
42 |
+
print(f"loading SD model : {args.model_org}")
|
43 |
+
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
44 |
+
print(f"loading SD model : {args.model_tuned}")
|
45 |
+
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
+
|
47 |
+
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
+
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
49 |
+
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
50 |
+
assert len(lora_network_o.text_encoder_loras) == len(
|
51 |
+
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
52 |
+
|
53 |
+
# get diffs
|
54 |
+
diffs = {}
|
55 |
+
text_encoder_different = False
|
56 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
57 |
+
lora_name = lora_o.lora_name
|
58 |
+
module_o = lora_o.org_module
|
59 |
+
module_t = lora_t.org_module
|
60 |
+
diff = module_t.weight - module_o.weight
|
61 |
+
|
62 |
+
# Text Encoder might be same
|
63 |
+
if torch.max(torch.abs(diff)) > MIN_DIFF:
|
64 |
+
text_encoder_different = True
|
65 |
+
|
66 |
+
diff = diff.float()
|
67 |
+
diffs[lora_name] = diff
|
68 |
+
|
69 |
+
if not text_encoder_different:
|
70 |
+
print("Text encoder is same. Extract U-Net only.")
|
71 |
+
lora_network_o.text_encoder_loras = []
|
72 |
+
diffs = {}
|
73 |
+
|
74 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
75 |
+
lora_name = lora_o.lora_name
|
76 |
+
module_o = lora_o.org_module
|
77 |
+
module_t = lora_t.org_module
|
78 |
+
diff = module_t.weight - module_o.weight
|
79 |
+
diff = diff.float()
|
80 |
+
|
81 |
+
if args.device:
|
82 |
+
diff = diff.to(args.device)
|
83 |
+
|
84 |
+
diffs[lora_name] = diff
|
85 |
+
|
86 |
+
# make LoRA with svd
|
87 |
+
print("calculating by svd")
|
88 |
+
rank = args.dim
|
89 |
+
lora_weights = {}
|
90 |
+
with torch.no_grad():
|
91 |
+
for lora_name, mat in tqdm(list(diffs.items())):
|
92 |
+
conv2d = (len(mat.size()) == 4)
|
93 |
+
if conv2d:
|
94 |
+
mat = mat.squeeze()
|
95 |
+
|
96 |
+
U, S, Vh = torch.linalg.svd(mat)
|
97 |
+
|
98 |
+
U = U[:, :rank]
|
99 |
+
S = S[:rank]
|
100 |
+
U = U @ torch.diag(S)
|
101 |
+
|
102 |
+
Vh = Vh[:rank, :]
|
103 |
+
|
104 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
105 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
106 |
+
low_val = -hi_val
|
107 |
+
|
108 |
+
U = U.clamp(low_val, hi_val)
|
109 |
+
Vh = Vh.clamp(low_val, hi_val)
|
110 |
+
|
111 |
+
lora_weights[lora_name] = (U, Vh)
|
112 |
+
|
113 |
+
# make state dict for LoRA
|
114 |
+
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
115 |
+
lora_sd = lora_network_o.state_dict()
|
116 |
+
print(f"LoRA has {len(lora_sd)} weights.")
|
117 |
+
|
118 |
+
for key in list(lora_sd.keys()):
|
119 |
+
if "alpha" in key:
|
120 |
+
continue
|
121 |
+
|
122 |
+
lora_name = key.split('.')[0]
|
123 |
+
i = 0 if "lora_up" in key else 1
|
124 |
+
|
125 |
+
weights = lora_weights[lora_name][i]
|
126 |
+
# print(key, i, weights.size(), lora_sd[key].size())
|
127 |
+
if len(lora_sd[key].size()) == 4:
|
128 |
+
weights = weights.unsqueeze(2).unsqueeze(3)
|
129 |
+
|
130 |
+
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
131 |
+
lora_sd[key] = weights
|
132 |
+
|
133 |
+
# load state dict to LoRA and save it
|
134 |
+
info = lora_network_o.load_state_dict(lora_sd)
|
135 |
+
print(f"Loading extracted LoRA weights: {info}")
|
136 |
+
|
137 |
+
dir_name = os.path.dirname(args.save_to)
|
138 |
+
if dir_name and not os.path.exists(dir_name):
|
139 |
+
os.makedirs(dir_name, exist_ok=True)
|
140 |
+
|
141 |
+
# minimum metadata
|
142 |
+
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
143 |
+
|
144 |
+
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
145 |
+
print(f"LoRA weights are saved to: {args.save_to}")
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == '__main__':
|
149 |
+
parser = argparse.ArgumentParser()
|
150 |
+
parser.add_argument("--v2", action='store_true',
|
151 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
152 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
153 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
|
154 |
+
parser.add_argument("--model_org", type=str, default=None,
|
155 |
+
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
|
156 |
+
parser.add_argument("--model_tuned", type=str, default=None,
|
157 |
+
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
158 |
+
parser.add_argument("--save_to", type=str, default=None,
|
159 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
160 |
+
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
161 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
svd(args)
|
networks/lora.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from typing import List
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from library import train_util
|
12 |
+
|
13 |
+
|
14 |
+
class LoRAModule(torch.nn.Module):
|
15 |
+
"""
|
16 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
20 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
21 |
+
super().__init__()
|
22 |
+
self.lora_name = lora_name
|
23 |
+
self.lora_dim = lora_dim
|
24 |
+
|
25 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
+
in_dim = org_module.in_channels
|
27 |
+
out_dim = org_module.out_channels
|
28 |
+
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
29 |
+
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
30 |
+
else:
|
31 |
+
in_dim = org_module.in_features
|
32 |
+
out_dim = org_module.out_features
|
33 |
+
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
34 |
+
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
35 |
+
|
36 |
+
if type(alpha) == torch.Tensor:
|
37 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
38 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
39 |
+
self.scale = alpha / self.lora_dim
|
40 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
41 |
+
|
42 |
+
# same as microsoft's
|
43 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
44 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
45 |
+
|
46 |
+
self.multiplier = multiplier
|
47 |
+
self.org_module = org_module # remove in applying
|
48 |
+
|
49 |
+
def apply_to(self):
|
50 |
+
self.org_forward = self.org_module.forward
|
51 |
+
self.org_module.forward = self.forward
|
52 |
+
del self.org_module
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
56 |
+
|
57 |
+
|
58 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
59 |
+
if network_dim is None:
|
60 |
+
network_dim = 4 # default
|
61 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
62 |
+
return network
|
63 |
+
|
64 |
+
|
65 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
66 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
67 |
+
from safetensors.torch import load_file, safe_open
|
68 |
+
weights_sd = load_file(file)
|
69 |
+
else:
|
70 |
+
weights_sd = torch.load(file, map_location='cpu')
|
71 |
+
|
72 |
+
# get dim (rank)
|
73 |
+
network_alpha = None
|
74 |
+
network_dim = None
|
75 |
+
for key, value in weights_sd.items():
|
76 |
+
if network_alpha is None and 'alpha' in key:
|
77 |
+
network_alpha = value
|
78 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
79 |
+
network_dim = value.size()[0]
|
80 |
+
|
81 |
+
if network_alpha is None:
|
82 |
+
network_alpha = network_dim
|
83 |
+
|
84 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
85 |
+
network.weights_sd = weights_sd
|
86 |
+
return network
|
87 |
+
|
88 |
+
|
89 |
+
class LoRANetwork(torch.nn.Module):
|
90 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
91 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
+
|
95 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
96 |
+
super().__init__()
|
97 |
+
self.multiplier = multiplier
|
98 |
+
self.lora_dim = lora_dim
|
99 |
+
self.alpha = alpha
|
100 |
+
|
101 |
+
# create module instances
|
102 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
103 |
+
loras = []
|
104 |
+
for name, module in root_module.named_modules():
|
105 |
+
if module.__class__.__name__ in target_replace_modules:
|
106 |
+
for child_name, child_module in module.named_modules():
|
107 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
108 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
109 |
+
lora_name = lora_name.replace('.', '_')
|
110 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
111 |
+
loras.append(lora)
|
112 |
+
return loras
|
113 |
+
|
114 |
+
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
115 |
+
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
116 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
117 |
+
|
118 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
119 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
120 |
+
|
121 |
+
self.weights_sd = None
|
122 |
+
|
123 |
+
# assertion
|
124 |
+
names = set()
|
125 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
126 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
+
names.add(lora.lora_name)
|
128 |
+
|
129 |
+
def load_weights(self, file):
|
130 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
+
from safetensors.torch import load_file, safe_open
|
132 |
+
self.weights_sd = load_file(file)
|
133 |
+
else:
|
134 |
+
self.weights_sd = torch.load(file, map_location='cpu')
|
135 |
+
|
136 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
137 |
+
if self.weights_sd:
|
138 |
+
weights_has_text_encoder = weights_has_unet = False
|
139 |
+
for key in self.weights_sd.keys():
|
140 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
141 |
+
weights_has_text_encoder = True
|
142 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
143 |
+
weights_has_unet = True
|
144 |
+
|
145 |
+
if apply_text_encoder is None:
|
146 |
+
apply_text_encoder = weights_has_text_encoder
|
147 |
+
else:
|
148 |
+
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
149 |
+
|
150 |
+
if apply_unet is None:
|
151 |
+
apply_unet = weights_has_unet
|
152 |
+
else:
|
153 |
+
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
154 |
+
else:
|
155 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
156 |
+
|
157 |
+
if apply_text_encoder:
|
158 |
+
print("enable LoRA for text encoder")
|
159 |
+
else:
|
160 |
+
self.text_encoder_loras = []
|
161 |
+
|
162 |
+
if apply_unet:
|
163 |
+
print("enable LoRA for U-Net")
|
164 |
+
else:
|
165 |
+
self.unet_loras = []
|
166 |
+
|
167 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
168 |
+
lora.apply_to()
|
169 |
+
self.add_module(lora.lora_name, lora)
|
170 |
+
|
171 |
+
if self.weights_sd:
|
172 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
173 |
+
info = self.load_state_dict(self.weights_sd, False)
|
174 |
+
print(f"weights are loaded: {info}")
|
175 |
+
|
176 |
+
def enable_gradient_checkpointing(self):
|
177 |
+
# not supported
|
178 |
+
pass
|
179 |
+
|
180 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
181 |
+
def enumerate_params(loras):
|
182 |
+
params = []
|
183 |
+
for lora in loras:
|
184 |
+
params.extend(lora.parameters())
|
185 |
+
return params
|
186 |
+
|
187 |
+
self.requires_grad_(True)
|
188 |
+
all_params = []
|
189 |
+
|
190 |
+
if self.text_encoder_loras:
|
191 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
192 |
+
if text_encoder_lr is not None:
|
193 |
+
param_data['lr'] = text_encoder_lr
|
194 |
+
all_params.append(param_data)
|
195 |
+
|
196 |
+
if self.unet_loras:
|
197 |
+
param_data = {'params': enumerate_params(self.unet_loras)}
|
198 |
+
if unet_lr is not None:
|
199 |
+
param_data['lr'] = unet_lr
|
200 |
+
all_params.append(param_data)
|
201 |
+
|
202 |
+
return all_params
|
203 |
+
|
204 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
205 |
+
self.requires_grad_(True)
|
206 |
+
|
207 |
+
def on_epoch_start(self, text_encoder, unet):
|
208 |
+
self.train()
|
209 |
+
|
210 |
+
def get_trainable_params(self):
|
211 |
+
return self.parameters()
|
212 |
+
|
213 |
+
def save_weights(self, file, dtype, metadata):
|
214 |
+
if metadata is not None and len(metadata) == 0:
|
215 |
+
metadata = None
|
216 |
+
|
217 |
+
state_dict = self.state_dict()
|
218 |
+
|
219 |
+
if dtype is not None:
|
220 |
+
for key in list(state_dict.keys()):
|
221 |
+
v = state_dict[key]
|
222 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
223 |
+
state_dict[key] = v
|
224 |
+
|
225 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
226 |
+
from safetensors.torch import save_file
|
227 |
+
|
228 |
+
# Precalculate model hashes to save time on indexing
|
229 |
+
if metadata is None:
|
230 |
+
metadata = {}
|
231 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
232 |
+
metadata["sshs_model_hash"] = model_hash
|
233 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
234 |
+
|
235 |
+
save_file(state_dict, file, metadata)
|
236 |
+
else:
|
237 |
+
torch.save(state_dict, file)
|
networks/lora_interrogator.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from library import model_util
|
5 |
+
import argparse
|
6 |
+
from transformers import CLIPTokenizer
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import library.model_util as model_util
|
10 |
+
import lora
|
11 |
+
|
12 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
13 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
14 |
+
|
15 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
|
18 |
+
def interrogate(args):
|
19 |
+
# いろいろ準備する
|
20 |
+
print(f"loading SD model: {args.sd_model}")
|
21 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
22 |
+
|
23 |
+
print(f"loading LoRA: {args.model}")
|
24 |
+
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
25 |
+
|
26 |
+
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
27 |
+
has_te_weight = False
|
28 |
+
for key in network.weights_sd.keys():
|
29 |
+
if 'lora_te' in key:
|
30 |
+
has_te_weight = True
|
31 |
+
break
|
32 |
+
if not has_te_weight:
|
33 |
+
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
34 |
+
return
|
35 |
+
del vae
|
36 |
+
|
37 |
+
print("loading tokenizer")
|
38 |
+
if args.v2:
|
39 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
40 |
+
else:
|
41 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
42 |
+
|
43 |
+
text_encoder.to(DEVICE)
|
44 |
+
text_encoder.eval()
|
45 |
+
unet.to(DEVICE)
|
46 |
+
unet.eval() # U-Netは呼び出さないので不要だけど
|
47 |
+
|
48 |
+
# トークンをひとつひとつ当たっていく
|
49 |
+
token_id_start = 0
|
50 |
+
token_id_end = max(tokenizer.all_special_ids)
|
51 |
+
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
52 |
+
|
53 |
+
def get_all_embeddings(text_encoder):
|
54 |
+
embs = []
|
55 |
+
with torch.no_grad():
|
56 |
+
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
57 |
+
batch = []
|
58 |
+
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
59 |
+
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
60 |
+
# tokens = [tid] # こちらは結果がいまひとつ
|
61 |
+
batch.append(tokens)
|
62 |
+
|
63 |
+
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
|
64 |
+
# clip skip対応
|
65 |
+
batch = torch.tensor(batch).to(DEVICE)
|
66 |
+
if args.clip_skip is None:
|
67 |
+
encoder_hidden_states = text_encoder(batch)[0]
|
68 |
+
else:
|
69 |
+
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
70 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
71 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
72 |
+
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
73 |
+
|
74 |
+
embs.extend(encoder_hidden_states)
|
75 |
+
return torch.stack(embs)
|
76 |
+
|
77 |
+
print("get original text encoder embeddings.")
|
78 |
+
orig_embs = get_all_embeddings(text_encoder)
|
79 |
+
|
80 |
+
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
81 |
+
network.to(DEVICE)
|
82 |
+
network.eval()
|
83 |
+
|
84 |
+
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
85 |
+
print("get text encoder embeddings with lora.")
|
86 |
+
lora_embs = get_all_embeddings(text_encoder)
|
87 |
+
|
88 |
+
# 比べる:とりあえず単純に差分の絶対値で
|
89 |
+
print("comparing...")
|
90 |
+
diffs = {}
|
91 |
+
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
92 |
+
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
93 |
+
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
94 |
+
diff = float(diff.detach().to('cpu').numpy())
|
95 |
+
diffs[token_id_start + i] = diff
|
96 |
+
|
97 |
+
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
98 |
+
|
99 |
+
# 結果を表示する
|
100 |
+
print("top 100:")
|
101 |
+
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
102 |
+
# if diff < 1e-6:
|
103 |
+
# break
|
104 |
+
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
105 |
+
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == '__main__':
|
109 |
+
parser = argparse.ArgumentParser()
|
110 |
+
parser.add_argument("--v2", action='store_true',
|
111 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
112 |
+
parser.add_argument("--sd_model", type=str, default=None,
|
113 |
+
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
114 |
+
parser.add_argument("--model", type=str, default=None,
|
115 |
+
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
116 |
+
parser.add_argument("--batch_size", type=int, default=16,
|
117 |
+
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
118 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
119 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
120 |
+
|
121 |
+
args = parser.parse_args()
|
122 |
+
interrogate(args)
|
networks/merge_lora.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from safetensors.torch import load_file, save_file
|
7 |
+
import library.model_util as model_util
|
8 |
+
import lora
|
9 |
+
|
10 |
+
|
11 |
+
def load_state_dict(file_name, dtype):
|
12 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
13 |
+
sd = load_file(file_name)
|
14 |
+
else:
|
15 |
+
sd = torch.load(file_name, map_location='cpu')
|
16 |
+
for key in list(sd.keys()):
|
17 |
+
if type(sd[key]) == torch.Tensor:
|
18 |
+
sd[key] = sd[key].to(dtype)
|
19 |
+
return sd
|
20 |
+
|
21 |
+
|
22 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
23 |
+
if dtype is not None:
|
24 |
+
for key in list(state_dict.keys()):
|
25 |
+
if type(state_dict[key]) == torch.Tensor:
|
26 |
+
state_dict[key] = state_dict[key].to(dtype)
|
27 |
+
|
28 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
29 |
+
save_file(model, file_name)
|
30 |
+
else:
|
31 |
+
torch.save(model, file_name)
|
32 |
+
|
33 |
+
|
34 |
+
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
35 |
+
text_encoder.to(merge_dtype)
|
36 |
+
unet.to(merge_dtype)
|
37 |
+
|
38 |
+
# create module map
|
39 |
+
name_to_module = {}
|
40 |
+
for i, root_module in enumerate([text_encoder, unet]):
|
41 |
+
if i == 0:
|
42 |
+
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
43 |
+
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
44 |
+
else:
|
45 |
+
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
46 |
+
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
47 |
+
|
48 |
+
for name, module in root_module.named_modules():
|
49 |
+
if module.__class__.__name__ in target_replace_modules:
|
50 |
+
for child_name, child_module in module.named_modules():
|
51 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
52 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
+
lora_name = lora_name.replace('.', '_')
|
54 |
+
name_to_module[lora_name] = child_module
|
55 |
+
|
56 |
+
for model, ratio in zip(models, ratios):
|
57 |
+
print(f"loading: {model}")
|
58 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
59 |
+
|
60 |
+
print(f"merging...")
|
61 |
+
for key in lora_sd.keys():
|
62 |
+
if "lora_down" in key:
|
63 |
+
up_key = key.replace("lora_down", "lora_up")
|
64 |
+
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
65 |
+
|
66 |
+
# find original module for this lora
|
67 |
+
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
68 |
+
if module_name not in name_to_module:
|
69 |
+
print(f"no module found for LoRA weight: {key}")
|
70 |
+
continue
|
71 |
+
module = name_to_module[module_name]
|
72 |
+
# print(f"apply {key} to {module}")
|
73 |
+
|
74 |
+
down_weight = lora_sd[key]
|
75 |
+
up_weight = lora_sd[up_key]
|
76 |
+
|
77 |
+
dim = down_weight.size()[0]
|
78 |
+
alpha = lora_sd.get(alpha_key, dim)
|
79 |
+
scale = alpha / dim
|
80 |
+
|
81 |
+
# W <- W + U * D
|
82 |
+
weight = module.weight
|
83 |
+
if len(weight.size()) == 2:
|
84 |
+
# linear
|
85 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
+
else:
|
87 |
+
# conv2d
|
88 |
+
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
89 |
+
).unsqueeze(2).unsqueeze(3) * scale
|
90 |
+
|
91 |
+
module.weight = torch.nn.Parameter(weight)
|
92 |
+
|
93 |
+
|
94 |
+
def merge_lora_models(models, ratios, merge_dtype):
|
95 |
+
base_alphas = {} # alpha for merged model
|
96 |
+
base_dims = {}
|
97 |
+
|
98 |
+
merged_sd = {}
|
99 |
+
for model, ratio in zip(models, ratios):
|
100 |
+
print(f"loading: {model}")
|
101 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
102 |
+
|
103 |
+
# get alpha and dim
|
104 |
+
alphas = {} # alpha for current model
|
105 |
+
dims = {} # dims for current model
|
106 |
+
for key in lora_sd.keys():
|
107 |
+
if 'alpha' in key:
|
108 |
+
lora_module_name = key[:key.rfind(".alpha")]
|
109 |
+
alpha = float(lora_sd[key].detach().numpy())
|
110 |
+
alphas[lora_module_name] = alpha
|
111 |
+
if lora_module_name not in base_alphas:
|
112 |
+
base_alphas[lora_module_name] = alpha
|
113 |
+
elif "lora_down" in key:
|
114 |
+
lora_module_name = key[:key.rfind(".lora_down")]
|
115 |
+
dim = lora_sd[key].size()[0]
|
116 |
+
dims[lora_module_name] = dim
|
117 |
+
if lora_module_name not in base_dims:
|
118 |
+
base_dims[lora_module_name] = dim
|
119 |
+
|
120 |
+
for lora_module_name in dims.keys():
|
121 |
+
if lora_module_name not in alphas:
|
122 |
+
alpha = dims[lora_module_name]
|
123 |
+
alphas[lora_module_name] = alpha
|
124 |
+
if lora_module_name not in base_alphas:
|
125 |
+
base_alphas[lora_module_name] = alpha
|
126 |
+
|
127 |
+
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
128 |
+
|
129 |
+
# merge
|
130 |
+
print(f"merging...")
|
131 |
+
for key in lora_sd.keys():
|
132 |
+
if 'alpha' in key:
|
133 |
+
continue
|
134 |
+
|
135 |
+
lora_module_name = key[:key.rfind(".lora_")]
|
136 |
+
|
137 |
+
base_alpha = base_alphas[lora_module_name]
|
138 |
+
alpha = alphas[lora_module_name]
|
139 |
+
|
140 |
+
scale = math.sqrt(alpha / base_alpha) * ratio
|
141 |
+
|
142 |
+
if key in merged_sd:
|
143 |
+
assert merged_sd[key].size() == lora_sd[key].size(
|
144 |
+
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズ��合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
145 |
+
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
146 |
+
else:
|
147 |
+
merged_sd[key] = lora_sd[key] * scale
|
148 |
+
|
149 |
+
# set alpha to sd
|
150 |
+
for lora_module_name, alpha in base_alphas.items():
|
151 |
+
key = lora_module_name + ".alpha"
|
152 |
+
merged_sd[key] = torch.tensor(alpha)
|
153 |
+
|
154 |
+
print("merged model")
|
155 |
+
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
156 |
+
|
157 |
+
return merged_sd
|
158 |
+
|
159 |
+
|
160 |
+
def merge(args):
|
161 |
+
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
162 |
+
|
163 |
+
def str_to_dtype(p):
|
164 |
+
if p == 'float':
|
165 |
+
return torch.float
|
166 |
+
if p == 'fp16':
|
167 |
+
return torch.float16
|
168 |
+
if p == 'bf16':
|
169 |
+
return torch.bfloat16
|
170 |
+
return None
|
171 |
+
|
172 |
+
merge_dtype = str_to_dtype(args.precision)
|
173 |
+
save_dtype = str_to_dtype(args.save_precision)
|
174 |
+
if save_dtype is None:
|
175 |
+
save_dtype = merge_dtype
|
176 |
+
|
177 |
+
if args.sd_model is not None:
|
178 |
+
print(f"loading SD model: {args.sd_model}")
|
179 |
+
|
180 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
181 |
+
|
182 |
+
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
183 |
+
|
184 |
+
print(f"saving SD model to: {args.save_to}")
|
185 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
186 |
+
args.sd_model, 0, 0, save_dtype, vae)
|
187 |
+
else:
|
188 |
+
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
189 |
+
|
190 |
+
print(f"saving model to: {args.save_to}")
|
191 |
+
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
192 |
+
|
193 |
+
|
194 |
+
if __name__ == '__main__':
|
195 |
+
parser = argparse.ArgumentParser()
|
196 |
+
parser.add_argument("--v2", action='store_true',
|
197 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
198 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
199 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
200 |
+
parser.add_argument("--precision", type=str, default="float",
|
201 |
+
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
202 |
+
parser.add_argument("--sd_model", type=str, default=None,
|
203 |
+
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
204 |
+
parser.add_argument("--save_to", type=str, default=None,
|
205 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
206 |
+
parser.add_argument("--models", type=str, nargs='*',
|
207 |
+
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
208 |
+
parser.add_argument("--ratios", type=float, nargs='*',
|
209 |
+
help="ratios for each model / それぞれのLoRAモデルの比率")
|
210 |
+
|
211 |
+
args = parser.parse_args()
|
212 |
+
merge(args)
|
networks/merge_lora_old.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from safetensors.torch import load_file, save_file
|
7 |
+
import library.model_util as model_util
|
8 |
+
import lora
|
9 |
+
|
10 |
+
|
11 |
+
def load_state_dict(file_name, dtype):
|
12 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
13 |
+
sd = load_file(file_name)
|
14 |
+
else:
|
15 |
+
sd = torch.load(file_name, map_location='cpu')
|
16 |
+
for key in list(sd.keys()):
|
17 |
+
if type(sd[key]) == torch.Tensor:
|
18 |
+
sd[key] = sd[key].to(dtype)
|
19 |
+
return sd
|
20 |
+
|
21 |
+
|
22 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
23 |
+
if dtype is not None:
|
24 |
+
for key in list(state_dict.keys()):
|
25 |
+
if type(state_dict[key]) == torch.Tensor:
|
26 |
+
state_dict[key] = state_dict[key].to(dtype)
|
27 |
+
|
28 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
29 |
+
save_file(model, file_name)
|
30 |
+
else:
|
31 |
+
torch.save(model, file_name)
|
32 |
+
|
33 |
+
|
34 |
+
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
35 |
+
text_encoder.to(merge_dtype)
|
36 |
+
unet.to(merge_dtype)
|
37 |
+
|
38 |
+
# create module map
|
39 |
+
name_to_module = {}
|
40 |
+
for i, root_module in enumerate([text_encoder, unet]):
|
41 |
+
if i == 0:
|
42 |
+
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
43 |
+
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
44 |
+
else:
|
45 |
+
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
46 |
+
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
47 |
+
|
48 |
+
for name, module in root_module.named_modules():
|
49 |
+
if module.__class__.__name__ in target_replace_modules:
|
50 |
+
for child_name, child_module in module.named_modules():
|
51 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
52 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
+
lora_name = lora_name.replace('.', '_')
|
54 |
+
name_to_module[lora_name] = child_module
|
55 |
+
|
56 |
+
for model, ratio in zip(models, ratios):
|
57 |
+
print(f"loading: {model}")
|
58 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
59 |
+
|
60 |
+
print(f"merging...")
|
61 |
+
for key in lora_sd.keys():
|
62 |
+
if "lora_down" in key:
|
63 |
+
up_key = key.replace("lora_down", "lora_up")
|
64 |
+
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
65 |
+
|
66 |
+
# find original module for this lora
|
67 |
+
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
68 |
+
if module_name not in name_to_module:
|
69 |
+
print(f"no module found for LoRA weight: {key}")
|
70 |
+
continue
|
71 |
+
module = name_to_module[module_name]
|
72 |
+
# print(f"apply {key} to {module}")
|
73 |
+
|
74 |
+
down_weight = lora_sd[key]
|
75 |
+
up_weight = lora_sd[up_key]
|
76 |
+
|
77 |
+
dim = down_weight.size()[0]
|
78 |
+
alpha = lora_sd.get(alpha_key, dim)
|
79 |
+
scale = alpha / dim
|
80 |
+
|
81 |
+
# W <- W + U * D
|
82 |
+
weight = module.weight
|
83 |
+
if len(weight.size()) == 2:
|
84 |
+
# linear
|
85 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
+
else:
|
87 |
+
# conv2d
|
88 |
+
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
89 |
+
|
90 |
+
module.weight = torch.nn.Parameter(weight)
|
91 |
+
|
92 |
+
|
93 |
+
def merge_lora_models(models, ratios, merge_dtype):
|
94 |
+
merged_sd = {}
|
95 |
+
|
96 |
+
alpha = None
|
97 |
+
dim = None
|
98 |
+
for model, ratio in zip(models, ratios):
|
99 |
+
print(f"loading: {model}")
|
100 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
101 |
+
|
102 |
+
print(f"merging...")
|
103 |
+
for key in lora_sd.keys():
|
104 |
+
if 'alpha' in key:
|
105 |
+
if key in merged_sd:
|
106 |
+
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
107 |
+
else:
|
108 |
+
alpha = lora_sd[key].detach().numpy()
|
109 |
+
merged_sd[key] = lora_sd[key]
|
110 |
+
else:
|
111 |
+
if key in merged_sd:
|
112 |
+
assert merged_sd[key].size() == lora_sd[key].size(
|
113 |
+
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
114 |
+
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
115 |
+
else:
|
116 |
+
if "lora_down" in key:
|
117 |
+
dim = lora_sd[key].size()[0]
|
118 |
+
merged_sd[key] = lora_sd[key] * ratio
|
119 |
+
|
120 |
+
print(f"dim (rank): {dim}, alpha: {alpha}")
|
121 |
+
if alpha is None:
|
122 |
+
alpha = dim
|
123 |
+
|
124 |
+
return merged_sd, dim, alpha
|
125 |
+
|
126 |
+
|
127 |
+
def merge(args):
|
128 |
+
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
129 |
+
|
130 |
+
def str_to_dtype(p):
|
131 |
+
if p == 'float':
|
132 |
+
return torch.float
|
133 |
+
if p == 'fp16':
|
134 |
+
return torch.float16
|
135 |
+
if p == 'bf16':
|
136 |
+
return torch.bfloat16
|
137 |
+
return None
|
138 |
+
|
139 |
+
merge_dtype = str_to_dtype(args.precision)
|
140 |
+
save_dtype = str_to_dtype(args.save_precision)
|
141 |
+
if save_dtype is None:
|
142 |
+
save_dtype = merge_dtype
|
143 |
+
|
144 |
+
if args.sd_model is not None:
|
145 |
+
print(f"loading SD model: {args.sd_model}")
|
146 |
+
|
147 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
148 |
+
|
149 |
+
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
150 |
+
|
151 |
+
print(f"saving SD model to: {args.save_to}")
|
152 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
153 |
+
args.sd_model, 0, 0, save_dtype, vae)
|
154 |
+
else:
|
155 |
+
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
156 |
+
|
157 |
+
print(f"saving model to: {args.save_to}")
|
158 |
+
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == '__main__':
|
162 |
+
parser = argparse.ArgumentParser()
|
163 |
+
parser.add_argument("--v2", action='store_true',
|
164 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
165 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
166 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
167 |
+
parser.add_argument("--precision", type=str, default="float",
|
168 |
+
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
169 |
+
parser.add_argument("--sd_model", type=str, default=None,
|
170 |
+
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
171 |
+
parser.add_argument("--save_to", type=str, default=None,
|
172 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
173 |
+
parser.add_argument("--models", type=str, nargs='*',
|
174 |
+
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
175 |
+
parser.add_argument("--ratios", type=float, nargs='*',
|
176 |
+
help="ratios for each model / それぞれのLoRAモデルの比率")
|
177 |
+
|
178 |
+
args = parser.parse_args()
|
179 |
+
merge(args)
|
networks/resize_lora.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
+
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo and kohya
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from safetensors.torch import load_file, save_file, safe_open
|
9 |
+
from tqdm import tqdm
|
10 |
+
from library import train_util, model_util
|
11 |
+
|
12 |
+
|
13 |
+
def load_state_dict(file_name, dtype):
|
14 |
+
if model_util.is_safetensors(file_name):
|
15 |
+
sd = load_file(file_name)
|
16 |
+
with safe_open(file_name, framework="pt") as f:
|
17 |
+
metadata = f.metadata()
|
18 |
+
else:
|
19 |
+
sd = torch.load(file_name, map_location='cpu')
|
20 |
+
metadata = None
|
21 |
+
|
22 |
+
for key in list(sd.keys()):
|
23 |
+
if type(sd[key]) == torch.Tensor:
|
24 |
+
sd[key] = sd[key].to(dtype)
|
25 |
+
|
26 |
+
return sd, metadata
|
27 |
+
|
28 |
+
|
29 |
+
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
30 |
+
if dtype is not None:
|
31 |
+
for key in list(state_dict.keys()):
|
32 |
+
if type(state_dict[key]) == torch.Tensor:
|
33 |
+
state_dict[key] = state_dict[key].to(dtype)
|
34 |
+
|
35 |
+
if model_util.is_safetensors(file_name):
|
36 |
+
save_file(model, file_name, metadata)
|
37 |
+
else:
|
38 |
+
torch.save(model, file_name)
|
39 |
+
|
40 |
+
|
41 |
+
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
42 |
+
network_alpha = None
|
43 |
+
network_dim = None
|
44 |
+
verbose_str = "\n"
|
45 |
+
|
46 |
+
CLAMP_QUANTILE = 0.99
|
47 |
+
|
48 |
+
# Extract loaded lora dim and alpha
|
49 |
+
for key, value in lora_sd.items():
|
50 |
+
if network_alpha is None and 'alpha' in key:
|
51 |
+
network_alpha = value
|
52 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
53 |
+
network_dim = value.size()[0]
|
54 |
+
if network_alpha is not None and network_dim is not None:
|
55 |
+
break
|
56 |
+
if network_alpha is None:
|
57 |
+
network_alpha = network_dim
|
58 |
+
|
59 |
+
scale = network_alpha/network_dim
|
60 |
+
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
61 |
+
|
62 |
+
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
63 |
+
|
64 |
+
lora_down_weight = None
|
65 |
+
lora_up_weight = None
|
66 |
+
|
67 |
+
o_lora_sd = lora_sd.copy()
|
68 |
+
block_down_name = None
|
69 |
+
block_up_name = None
|
70 |
+
|
71 |
+
print("resizing lora...")
|
72 |
+
with torch.no_grad():
|
73 |
+
for key, value in tqdm(lora_sd.items()):
|
74 |
+
if 'lora_down' in key:
|
75 |
+
block_down_name = key.split(".")[0]
|
76 |
+
lora_down_weight = value
|
77 |
+
if 'lora_up' in key:
|
78 |
+
block_up_name = key.split(".")[0]
|
79 |
+
lora_up_weight = value
|
80 |
+
|
81 |
+
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
82 |
+
|
83 |
+
if (block_down_name == block_up_name) and weights_loaded:
|
84 |
+
|
85 |
+
conv2d = (len(lora_down_weight.size()) == 4)
|
86 |
+
|
87 |
+
if conv2d:
|
88 |
+
lora_down_weight = lora_down_weight.squeeze()
|
89 |
+
lora_up_weight = lora_up_weight.squeeze()
|
90 |
+
|
91 |
+
if device:
|
92 |
+
org_device = lora_up_weight.device
|
93 |
+
lora_up_weight = lora_up_weight.to(args.device)
|
94 |
+
lora_down_weight = lora_down_weight.to(args.device)
|
95 |
+
|
96 |
+
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
97 |
+
|
98 |
+
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
99 |
+
|
100 |
+
if verbose:
|
101 |
+
s_sum = torch.sum(torch.abs(S))
|
102 |
+
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
103 |
+
verbose_str+=f"{block_down_name:76} | "
|
104 |
+
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
105 |
+
|
106 |
+
U = U[:, :new_rank]
|
107 |
+
S = S[:new_rank]
|
108 |
+
U = U @ torch.diag(S)
|
109 |
+
|
110 |
+
Vh = Vh[:new_rank, :]
|
111 |
+
|
112 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
113 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
114 |
+
low_val = -hi_val
|
115 |
+
|
116 |
+
U = U.clamp(low_val, hi_val)
|
117 |
+
Vh = Vh.clamp(low_val, hi_val)
|
118 |
+
|
119 |
+
if conv2d:
|
120 |
+
U = U.unsqueeze(2).unsqueeze(3)
|
121 |
+
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
122 |
+
|
123 |
+
if device:
|
124 |
+
U = U.to(org_device)
|
125 |
+
Vh = Vh.to(org_device)
|
126 |
+
|
127 |
+
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
128 |
+
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
129 |
+
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
130 |
+
|
131 |
+
block_down_name = None
|
132 |
+
block_up_name = None
|
133 |
+
lora_down_weight = None
|
134 |
+
lora_up_weight = None
|
135 |
+
weights_loaded = False
|
136 |
+
|
137 |
+
if verbose:
|
138 |
+
print(verbose_str)
|
139 |
+
print("resizing complete")
|
140 |
+
return o_lora_sd, network_dim, new_alpha
|
141 |
+
|
142 |
+
|
143 |
+
def resize(args):
|
144 |
+
|
145 |
+
def str_to_dtype(p):
|
146 |
+
if p == 'float':
|
147 |
+
return torch.float
|
148 |
+
if p == 'fp16':
|
149 |
+
return torch.float16
|
150 |
+
if p == 'bf16':
|
151 |
+
return torch.bfloat16
|
152 |
+
return None
|
153 |
+
|
154 |
+
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
155 |
+
save_dtype = str_to_dtype(args.save_precision)
|
156 |
+
if save_dtype is None:
|
157 |
+
save_dtype = merge_dtype
|
158 |
+
|
159 |
+
print("loading Model...")
|
160 |
+
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
161 |
+
|
162 |
+
print("resizing rank...")
|
163 |
+
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
164 |
+
|
165 |
+
# update metadata
|
166 |
+
if metadata is None:
|
167 |
+
metadata = {}
|
168 |
+
|
169 |
+
comment = metadata.get("ss_training_comment", "")
|
170 |
+
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
171 |
+
metadata["ss_network_dim"] = str(args.new_rank)
|
172 |
+
metadata["ss_network_alpha"] = str(new_alpha)
|
173 |
+
|
174 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
175 |
+
metadata["sshs_model_hash"] = model_hash
|
176 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
177 |
+
|
178 |
+
print(f"saving model to: {args.save_to}")
|
179 |
+
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
180 |
+
|
181 |
+
|
182 |
+
if __name__ == '__main__':
|
183 |
+
parser = argparse.ArgumentParser()
|
184 |
+
|
185 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
186 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
|
187 |
+
parser.add_argument("--new_rank", type=int, default=4,
|
188 |
+
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
189 |
+
parser.add_argument("--save_to", type=str, default=None,
|
190 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
191 |
+
parser.add_argument("--model", type=str, default=None,
|
192 |
+
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
193 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
194 |
+
parser.add_argument("--verbose", action="store_true",
|
195 |
+
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
196 |
+
|
197 |
+
args = parser.parse_args()
|
198 |
+
resize(args)
|
networks/svd_merge_lora.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from safetensors.torch import load_file, save_file
|
7 |
+
from tqdm import tqdm
|
8 |
+
import library.model_util as model_util
|
9 |
+
import lora
|
10 |
+
|
11 |
+
|
12 |
+
CLAMP_QUANTILE = 0.99
|
13 |
+
|
14 |
+
|
15 |
+
def load_state_dict(file_name, dtype):
|
16 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
17 |
+
sd = load_file(file_name)
|
18 |
+
else:
|
19 |
+
sd = torch.load(file_name, map_location='cpu')
|
20 |
+
for key in list(sd.keys()):
|
21 |
+
if type(sd[key]) == torch.Tensor:
|
22 |
+
sd[key] = sd[key].to(dtype)
|
23 |
+
return sd
|
24 |
+
|
25 |
+
|
26 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
27 |
+
if dtype is not None:
|
28 |
+
for key in list(state_dict.keys()):
|
29 |
+
if type(state_dict[key]) == torch.Tensor:
|
30 |
+
state_dict[key] = state_dict[key].to(dtype)
|
31 |
+
|
32 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
33 |
+
save_file(model, file_name)
|
34 |
+
else:
|
35 |
+
torch.save(model, file_name)
|
36 |
+
|
37 |
+
|
38 |
+
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
39 |
+
merged_sd = {}
|
40 |
+
for model, ratio in zip(models, ratios):
|
41 |
+
print(f"loading: {model}")
|
42 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
43 |
+
|
44 |
+
# merge
|
45 |
+
print(f"merging...")
|
46 |
+
for key in tqdm(list(lora_sd.keys())):
|
47 |
+
if 'lora_down' not in key:
|
48 |
+
continue
|
49 |
+
|
50 |
+
lora_module_name = key[:key.rfind(".lora_down")]
|
51 |
+
|
52 |
+
down_weight = lora_sd[key]
|
53 |
+
network_dim = down_weight.size()[0]
|
54 |
+
|
55 |
+
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
56 |
+
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
57 |
+
|
58 |
+
in_dim = down_weight.size()[1]
|
59 |
+
out_dim = up_weight.size()[0]
|
60 |
+
conv2d = len(down_weight.size()) == 4
|
61 |
+
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
62 |
+
|
63 |
+
# make original weight if not exist
|
64 |
+
if lora_module_name not in merged_sd:
|
65 |
+
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
66 |
+
if device:
|
67 |
+
weight = weight.to(device)
|
68 |
+
else:
|
69 |
+
weight = merged_sd[lora_module_name]
|
70 |
+
|
71 |
+
# merge to weight
|
72 |
+
if device:
|
73 |
+
up_weight = up_weight.to(device)
|
74 |
+
down_weight = down_weight.to(device)
|
75 |
+
|
76 |
+
# W <- W + U * D
|
77 |
+
scale = (alpha / network_dim)
|
78 |
+
if not conv2d: # linear
|
79 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale
|
80 |
+
else:
|
81 |
+
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
82 |
+
).unsqueeze(2).unsqueeze(3) * scale
|
83 |
+
|
84 |
+
merged_sd[lora_module_name] = weight
|
85 |
+
|
86 |
+
# extract from merged weights
|
87 |
+
print("extract new lora...")
|
88 |
+
merged_lora_sd = {}
|
89 |
+
with torch.no_grad():
|
90 |
+
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
91 |
+
conv2d = (len(mat.size()) == 4)
|
92 |
+
if conv2d:
|
93 |
+
mat = mat.squeeze()
|
94 |
+
|
95 |
+
U, S, Vh = torch.linalg.svd(mat)
|
96 |
+
|
97 |
+
U = U[:, :new_rank]
|
98 |
+
S = S[:new_rank]
|
99 |
+
U = U @ torch.diag(S)
|
100 |
+
|
101 |
+
Vh = Vh[:new_rank, :]
|
102 |
+
|
103 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
104 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
105 |
+
low_val = -hi_val
|
106 |
+
|
107 |
+
U = U.clamp(low_val, hi_val)
|
108 |
+
Vh = Vh.clamp(low_val, hi_val)
|
109 |
+
|
110 |
+
up_weight = U
|
111 |
+
down_weight = Vh
|
112 |
+
|
113 |
+
if conv2d:
|
114 |
+
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
115 |
+
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
116 |
+
|
117 |
+
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
118 |
+
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
119 |
+
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
120 |
+
|
121 |
+
return merged_lora_sd
|
122 |
+
|
123 |
+
|
124 |
+
def merge(args):
|
125 |
+
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
126 |
+
|
127 |
+
def str_to_dtype(p):
|
128 |
+
if p == 'float':
|
129 |
+
return torch.float
|
130 |
+
if p == 'fp16':
|
131 |
+
return torch.float16
|
132 |
+
if p == 'bf16':
|
133 |
+
return torch.bfloat16
|
134 |
+
return None
|
135 |
+
|
136 |
+
merge_dtype = str_to_dtype(args.precision)
|
137 |
+
save_dtype = str_to_dtype(args.save_precision)
|
138 |
+
if save_dtype is None:
|
139 |
+
save_dtype = merge_dtype
|
140 |
+
|
141 |
+
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
142 |
+
|
143 |
+
print(f"saving model to: {args.save_to}")
|
144 |
+
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == '__main__':
|
148 |
+
parser = argparse.ArgumentParser()
|
149 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
150 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
151 |
+
parser.add_argument("--precision", type=str, default="float",
|
152 |
+
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
153 |
+
parser.add_argument("--save_to", type=str, default=None,
|
154 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
155 |
+
parser.add_argument("--models", type=str, nargs='*',
|
156 |
+
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
157 |
+
parser.add_argument("--ratios", type=float, nargs='*',
|
158 |
+
help="ratios for each model / それぞれのLoRAモデルの比率")
|
159 |
+
parser.add_argument("--new_rank", type=int, default=4,
|
160 |
+
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
161 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
merge(args)
|
requirements.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
transformers==4.26.0
|
3 |
+
ftfy==6.1.1
|
4 |
+
albumentations==1.3.0
|
5 |
+
opencv-python==4.7.0.68
|
6 |
+
einops==0.6.0
|
7 |
+
diffusers[torch]==0.10.2
|
8 |
+
pytorch-lightning==1.9.0
|
9 |
+
bitsandbytes==0.35.0
|
10 |
+
tensorboard==2.10.1
|
11 |
+
safetensors==0.2.6
|
12 |
+
gradio==3.16.2
|
13 |
+
altair==4.2.2
|
14 |
+
easygui==0.98.3
|
15 |
+
# for BLIP captioning
|
16 |
+
requests==2.28.2
|
17 |
+
timm==0.6.12
|
18 |
+
fairscale==0.4.13
|
19 |
+
# for WD14 captioning
|
20 |
+
# tensorflow<2.11
|
21 |
+
tensorflow==2.10.1
|
22 |
+
huggingface-hub==0.12.0
|
23 |
+
# for kohya_ss library
|
24 |
+
#locon.locon_kohya
|
25 |
+
.
|
requirements_startup.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
transformers==4.26.0
|
3 |
+
ftfy==6.1.1
|
4 |
+
albumentations==1.3.0
|
5 |
+
opencv-python==4.7.0.68
|
6 |
+
einops==0.6.0
|
7 |
+
diffusers[torch]==0.10.2
|
8 |
+
pytorch-lightning==1.9.0
|
9 |
+
bitsandbytes==0.35.0
|
10 |
+
tensorboard==2.10.1
|
11 |
+
safetensors==0.2.6
|
12 |
+
gradio==3.18.0
|
13 |
+
altair==4.2.2
|
14 |
+
easygui==0.98.3
|
15 |
+
# for BLIP captioning
|
16 |
+
requests==2.28.2
|
17 |
+
timm==0.4.12
|
18 |
+
fairscale==0.4.4
|
19 |
+
# for WD14 captioning
|
20 |
+
tensorflow==2.10.1
|
21 |
+
huggingface-hub==0.12.0
|
22 |
+
# for kohya_ss library
|
23 |
+
.
|
setup.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(name = "library", packages = find_packages())
|
tools/convert_diffusers20_original_sd.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from diffusers import StableDiffusionPipeline
|
7 |
+
|
8 |
+
import library.model_util as model_util
|
9 |
+
|
10 |
+
|
11 |
+
def convert(args):
|
12 |
+
# 引数を確認する
|
13 |
+
load_dtype = torch.float16 if args.fp16 else None
|
14 |
+
|
15 |
+
save_dtype = None
|
16 |
+
if args.fp16:
|
17 |
+
save_dtype = torch.float16
|
18 |
+
elif args.bf16:
|
19 |
+
save_dtype = torch.bfloat16
|
20 |
+
elif args.float:
|
21 |
+
save_dtype = torch.float
|
22 |
+
|
23 |
+
is_load_ckpt = os.path.isfile(args.model_to_load)
|
24 |
+
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
25 |
+
|
26 |
+
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
27 |
+
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
28 |
+
|
29 |
+
# モデルを読み込む
|
30 |
+
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
31 |
+
print(f"loading {msg}: {args.model_to_load}")
|
32 |
+
|
33 |
+
if is_load_ckpt:
|
34 |
+
v2_model = args.v2
|
35 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
36 |
+
else:
|
37 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
|
38 |
+
text_encoder = pipe.text_encoder
|
39 |
+
vae = pipe.vae
|
40 |
+
unet = pipe.unet
|
41 |
+
|
42 |
+
if args.v1 == args.v2:
|
43 |
+
# 自動判定する
|
44 |
+
v2_model = unet.config.cross_attention_dim == 1024
|
45 |
+
print("checking model version: model is " + ('v2' if v2_model else 'v1'))
|
46 |
+
else:
|
47 |
+
v2_model = not args.v1
|
48 |
+
|
49 |
+
# 変換して保存する
|
50 |
+
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
51 |
+
print(f"converting and saving as {msg}: {args.model_to_save}")
|
52 |
+
|
53 |
+
if is_save_ckpt:
|
54 |
+
original_model = args.model_to_load if is_load_ckpt else None
|
55 |
+
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
|
56 |
+
original_model, args.epoch, args.global_step, save_dtype, vae)
|
57 |
+
print(f"model saved. total converted state_dict keys: {key_count}")
|
58 |
+
else:
|
59 |
+
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
60 |
+
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
|
61 |
+
print(f"model saved.")
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
parser = argparse.ArgumentParser()
|
66 |
+
parser.add_argument("--v1", action='store_true',
|
67 |
+
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
68 |
+
parser.add_argument("--v2", action='store_true',
|
69 |
+
help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
|
70 |
+
parser.add_argument("--fp16", action='store_true',
|
71 |
+
help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
|
72 |
+
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
73 |
+
parser.add_argument("--float", action='store_true',
|
74 |
+
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
75 |
+
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
76 |
+
parser.add_argument("--global_step", type=int, default=0,
|
77 |
+
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
78 |
+
parser.add_argument("--reference_model", type=str, default=None,
|
79 |
+
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
80 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
81 |
+
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
|
82 |
+
|
83 |
+
parser.add_argument("model_to_load", type=str, default=None,
|
84 |
+
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
85 |
+
parser.add_argument("model_to_save", type=str, default=None,
|
86 |
+
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
87 |
+
|
88 |
+
args = parser.parse_args()
|
89 |
+
convert(args)
|
tools/detect_face_rotate.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
2 |
+
# (c) 2022 Kohya S. @kohya_ss
|
3 |
+
|
4 |
+
# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
|
5 |
+
|
6 |
+
# v2: extract max face if multiple faces are found
|
7 |
+
# v3: add crop_ratio option
|
8 |
+
# v4: add multiple faces extraction and min/max size
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import math
|
12 |
+
import cv2
|
13 |
+
import glob
|
14 |
+
import os
|
15 |
+
from anime_face_detector import create_detector
|
16 |
+
from tqdm import tqdm
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
KP_REYE = 11
|
20 |
+
KP_LEYE = 19
|
21 |
+
|
22 |
+
SCORE_THRES = 0.90
|
23 |
+
|
24 |
+
|
25 |
+
def detect_faces(detector, image, min_size):
|
26 |
+
preds = detector(image) # bgr
|
27 |
+
# print(len(preds))
|
28 |
+
|
29 |
+
faces = []
|
30 |
+
for pred in preds:
|
31 |
+
bb = pred['bbox']
|
32 |
+
score = bb[-1]
|
33 |
+
if score < SCORE_THRES:
|
34 |
+
continue
|
35 |
+
|
36 |
+
left, top, right, bottom = bb[:4]
|
37 |
+
cx = int((left + right) / 2)
|
38 |
+
cy = int((top + bottom) / 2)
|
39 |
+
fw = int(right - left)
|
40 |
+
fh = int(bottom - top)
|
41 |
+
|
42 |
+
lex, ley = pred['keypoints'][KP_LEYE, 0:2]
|
43 |
+
rex, rey = pred['keypoints'][KP_REYE, 0:2]
|
44 |
+
angle = math.atan2(ley - rey, lex - rex)
|
45 |
+
angle = angle / math.pi * 180
|
46 |
+
|
47 |
+
faces.append((cx, cy, fw, fh, angle))
|
48 |
+
|
49 |
+
faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順
|
50 |
+
return faces
|
51 |
+
|
52 |
+
|
53 |
+
def rotate_image(image, angle, cx, cy):
|
54 |
+
h, w = image.shape[0:2]
|
55 |
+
rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
|
56 |
+
|
57 |
+
# # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
|
58 |
+
# nh = max(h, int(w * math.sin(angle)))
|
59 |
+
# nw = max(w, int(h * math.sin(angle)))
|
60 |
+
# if nh > h or nw > w:
|
61 |
+
# pad_y = nh - h
|
62 |
+
# pad_t = pad_y // 2
|
63 |
+
# pad_x = nw - w
|
64 |
+
# pad_l = pad_x // 2
|
65 |
+
# m = np.array([[0, 0, pad_l],
|
66 |
+
# [0, 0, pad_t]])
|
67 |
+
# rot_mat = rot_mat + m
|
68 |
+
# h, w = nh, nw
|
69 |
+
# cx += pad_l
|
70 |
+
# cy += pad_t
|
71 |
+
|
72 |
+
result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
|
73 |
+
return result, cx, cy
|
74 |
+
|
75 |
+
|
76 |
+
def process(args):
|
77 |
+
assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
|
78 |
+
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
79 |
+
|
80 |
+
# アニメ顔検出モデルを読み込む
|
81 |
+
print("loading face detector.")
|
82 |
+
detector = create_detector('yolov3')
|
83 |
+
|
84 |
+
# cropの引数を解析する
|
85 |
+
if args.crop_size is None:
|
86 |
+
crop_width = crop_height = None
|
87 |
+
else:
|
88 |
+
tokens = args.crop_size.split(',')
|
89 |
+
assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
|
90 |
+
crop_width, crop_height = [int(t) for t in tokens]
|
91 |
+
|
92 |
+
if args.crop_ratio is None:
|
93 |
+
crop_h_ratio = crop_v_ratio = None
|
94 |
+
else:
|
95 |
+
tokens = args.crop_ratio.split(',')
|
96 |
+
assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
|
97 |
+
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
98 |
+
|
99 |
+
# 画像を処理する
|
100 |
+
print("processing.")
|
101 |
+
output_extension = ".png"
|
102 |
+
|
103 |
+
os.makedirs(args.dst_dir, exist_ok=True)
|
104 |
+
paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
|
105 |
+
glob.glob(os.path.join(args.src_dir, "*.webp"))
|
106 |
+
for path in tqdm(paths):
|
107 |
+
basename = os.path.splitext(os.path.basename(path))[0]
|
108 |
+
|
109 |
+
# image = cv2.imread(path) # 日本語ファイル名でエラーになる
|
110 |
+
image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
|
111 |
+
if len(image.shape) == 2:
|
112 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
113 |
+
if image.shape[2] == 4:
|
114 |
+
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
115 |
+
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
116 |
+
|
117 |
+
h, w = image.shape[:2]
|
118 |
+
|
119 |
+
faces = detect_faces(detector, image, args.multiple_faces)
|
120 |
+
for i, face in enumerate(faces):
|
121 |
+
cx, cy, fw, fh, angle = face
|
122 |
+
face_size = max(fw, fh)
|
123 |
+
if args.min_size is not None and face_size < args.min_size:
|
124 |
+
continue
|
125 |
+
if args.max_size is not None and face_size >= args.max_size:
|
126 |
+
continue
|
127 |
+
face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
|
128 |
+
|
129 |
+
# オプション指定があれば回転する
|
130 |
+
face_img = image
|
131 |
+
if args.rotate:
|
132 |
+
face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
|
133 |
+
|
134 |
+
# オプション指定があれば顔を中心に切り出す
|
135 |
+
if crop_width is not None or crop_h_ratio is not None:
|
136 |
+
cur_crop_width, cur_crop_height = crop_width, crop_height
|
137 |
+
if crop_h_ratio is not None:
|
138 |
+
cur_crop_width = int(face_size * crop_h_ratio + .5)
|
139 |
+
cur_crop_height = int(face_size * crop_v_ratio + .5)
|
140 |
+
|
141 |
+
# リサイズを必要なら行う
|
142 |
+
scale = 1.0
|
143 |
+
if args.resize_face_size is not None:
|
144 |
+
# 顔サイズを基準にリサイズする
|
145 |
+
scale = args.resize_face_size / face_size
|
146 |
+
if scale < cur_crop_width / w:
|
147 |
+
print(
|
148 |
+
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
149 |
+
scale = cur_crop_width / w
|
150 |
+
if scale < cur_crop_height / h:
|
151 |
+
print(
|
152 |
+
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
153 |
+
scale = cur_crop_height / h
|
154 |
+
elif crop_h_ratio is not None:
|
155 |
+
# 倍率指定の時にはリサイズしない
|
156 |
+
pass
|
157 |
+
else:
|
158 |
+
# 切り出しサイズ指定あり
|
159 |
+
if w < cur_crop_width:
|
160 |
+
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
161 |
+
scale = cur_crop_width / w
|
162 |
+
if h < cur_crop_height:
|
163 |
+
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
164 |
+
scale = cur_crop_height / h
|
165 |
+
if args.resize_fit:
|
166 |
+
scale = max(cur_crop_width / w, cur_crop_height / h)
|
167 |
+
|
168 |
+
if scale != 1.0:
|
169 |
+
w = int(w * scale + .5)
|
170 |
+
h = int(h * scale + .5)
|
171 |
+
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
|
172 |
+
cx = int(cx * scale + .5)
|
173 |
+
cy = int(cy * scale + .5)
|
174 |
+
fw = int(fw * scale + .5)
|
175 |
+
fh = int(fh * scale + .5)
|
176 |
+
|
177 |
+
cur_crop_width = min(cur_crop_width, face_img.shape[1])
|
178 |
+
cur_crop_height = min(cur_crop_height, face_img.shape[0])
|
179 |
+
|
180 |
+
x = cx - cur_crop_width // 2
|
181 |
+
cx = cur_crop_width // 2
|
182 |
+
if x < 0:
|
183 |
+
cx = cx + x
|
184 |
+
x = 0
|
185 |
+
elif x + cur_crop_width > w:
|
186 |
+
cx = cx + (x + cur_crop_width - w)
|
187 |
+
x = w - cur_crop_width
|
188 |
+
face_img = face_img[:, x:x+cur_crop_width]
|
189 |
+
|
190 |
+
y = cy - cur_crop_height // 2
|
191 |
+
cy = cur_crop_height // 2
|
192 |
+
if y < 0:
|
193 |
+
cy = cy + y
|
194 |
+
y = 0
|
195 |
+
elif y + cur_crop_height > h:
|
196 |
+
cy = cy + (y + cur_crop_height - h)
|
197 |
+
y = h - cur_crop_height
|
198 |
+
face_img = face_img[y:y + cur_crop_height]
|
199 |
+
|
200 |
+
# # debug
|
201 |
+
# print(path, cx, cy, angle)
|
202 |
+
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
203 |
+
# cv2.imshow("image", crp)
|
204 |
+
# if cv2.waitKey() == 27:
|
205 |
+
# break
|
206 |
+
# cv2.destroyAllWindows()
|
207 |
+
|
208 |
+
# debug
|
209 |
+
if args.debug:
|
210 |
+
cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
|
211 |
+
|
212 |
+
_, buf = cv2.imencode(output_extension, face_img)
|
213 |
+
with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
|
214 |
+
buf.tofile(f)
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == '__main__':
|
218 |
+
parser = argparse.ArgumentParser()
|
219 |
+
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
|
220 |
+
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
|
221 |
+
parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
|
222 |
+
parser.add_argument("--resize_fit", action="store_true",
|
223 |
+
help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
|
224 |
+
parser.add_argument("--resize_face_size", type=int, default=None,
|
225 |
+
help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
|
226 |
+
parser.add_argument("--crop_size", type=str, default=None,
|
227 |
+
help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
|
228 |
+
parser.add_argument("--crop_ratio", type=str, default=None,
|
229 |
+
help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
|
230 |
+
parser.add_argument("--min_size", type=int, default=None,
|
231 |
+
help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
|
232 |
+
parser.add_argument("--max_size", type=int, default=None,
|
233 |
+
help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
|
234 |
+
parser.add_argument("--multiple_faces", action="store_true",
|
235 |
+
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
|
236 |
+
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
237 |
+
args = parser.parse_args()
|
238 |
+
|
239 |
+
process(args)
|
tools/resize_images_to_resolution.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import argparse
|
5 |
+
import shutil
|
6 |
+
import math
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
12 |
+
# Split the max_resolution string by "," and strip any whitespaces
|
13 |
+
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
14 |
+
|
15 |
+
# # Calculate max_pixels from max_resolution string
|
16 |
+
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
17 |
+
|
18 |
+
# Create destination folder if it does not exist
|
19 |
+
if not os.path.exists(dst_img_folder):
|
20 |
+
os.makedirs(dst_img_folder)
|
21 |
+
|
22 |
+
# Select interpolation method
|
23 |
+
if interpolation == 'lanczos4':
|
24 |
+
cv2_interpolation = cv2.INTER_LANCZOS4
|
25 |
+
elif interpolation == 'cubic':
|
26 |
+
cv2_interpolation = cv2.INTER_CUBIC
|
27 |
+
else:
|
28 |
+
cv2_interpolation = cv2.INTER_AREA
|
29 |
+
|
30 |
+
# Iterate through all files in src_img_folder
|
31 |
+
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
32 |
+
for filename in os.listdir(src_img_folder):
|
33 |
+
# Check if the image is png, jpg or webp etc...
|
34 |
+
if not filename.endswith(img_exts):
|
35 |
+
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
|
36 |
+
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
37 |
+
continue
|
38 |
+
|
39 |
+
# Load image
|
40 |
+
# img = cv2.imread(os.path.join(src_img_folder, filename))
|
41 |
+
image = Image.open(os.path.join(src_img_folder, filename))
|
42 |
+
if not image.mode == "RGB":
|
43 |
+
image = image.convert("RGB")
|
44 |
+
img = np.array(image, np.uint8)
|
45 |
+
|
46 |
+
base, _ = os.path.splitext(filename)
|
47 |
+
for max_resolution in max_resolutions:
|
48 |
+
# Calculate max_pixels from max_resolution string
|
49 |
+
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
50 |
+
|
51 |
+
# Calculate current number of pixels
|
52 |
+
current_pixels = img.shape[0] * img.shape[1]
|
53 |
+
|
54 |
+
# Check if the image needs resizing
|
55 |
+
if current_pixels > max_pixels:
|
56 |
+
# Calculate scaling factor
|
57 |
+
scale_factor = max_pixels / current_pixels
|
58 |
+
|
59 |
+
# Calculate new dimensions
|
60 |
+
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
61 |
+
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
62 |
+
|
63 |
+
# Resize image
|
64 |
+
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
65 |
+
else:
|
66 |
+
new_height, new_width = img.shape[0:2]
|
67 |
+
|
68 |
+
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
|
69 |
+
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
70 |
+
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
71 |
+
|
72 |
+
# Center crop the image to the calculated dimensions
|
73 |
+
y = int((img.shape[0] - new_height) / 2)
|
74 |
+
x = int((img.shape[1] - new_width) / 2)
|
75 |
+
img = img[y:y + new_height, x:x + new_width]
|
76 |
+
|
77 |
+
# Split filename into base and extension
|
78 |
+
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
|
79 |
+
|
80 |
+
# Save resized image in dst_img_folder
|
81 |
+
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
82 |
+
image = Image.fromarray(img)
|
83 |
+
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
84 |
+
|
85 |
+
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
86 |
+
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
87 |
+
|
88 |
+
# If other files with same basename, copy them with resolution suffix
|
89 |
+
if copy_associated_files:
|
90 |
+
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
|
91 |
+
for asoc_file in asoc_files:
|
92 |
+
ext = os.path.splitext(asoc_file)[1]
|
93 |
+
if ext in img_exts:
|
94 |
+
continue
|
95 |
+
for max_resolution in max_resolutions:
|
96 |
+
new_asoc_file = base + '+' + max_resolution + ext
|
97 |
+
print(f"Copy {asoc_file} as {new_asoc_file}")
|
98 |
+
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
99 |
+
|
100 |
+
|
101 |
+
def main():
|
102 |
+
parser = argparse.ArgumentParser(
|
103 |
+
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
|
104 |
+
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
|
105 |
+
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
|
106 |
+
parser.add_argument('--max_resolution', type=str,
|
107 |
+
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
108 |
+
parser.add_argument('--divisible_by', type=int,
|
109 |
+
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
110 |
+
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
111 |
+
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
112 |
+
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
113 |
+
parser.add_argument('--copy_associated_files', action='store_true',
|
114 |
+
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
115 |
+
|
116 |
+
args = parser.parse_args()
|
117 |
+
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
|
118 |
+
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == '__main__':
|
122 |
+
main()
|