abc committed
Commit 3249d87 · 1 parent: 74be2a5
Upload 55 files
Browse files
- .gitattributes +1 -0
- .github/workflows/typos.yml +21 -0
- .gitignore +7 -0
- LICENSE.md +201 -0
- README-ja.md +147 -0
- README.md +230 -0
- append_module.py +378 -56
- bitsandbytes_windows/cextension.py +54 -0
- bitsandbytes_windows/libbitsandbytes_cpu.dll +0 -0
- bitsandbytes_windows/libbitsandbytes_cuda116.dll +3 -0
- bitsandbytes_windows/main.py +166 -0
- config_README-ja.md +279 -0
- fine_tune.py +50 -45
- fine_tune_README_ja.md +140 -0
- finetune/blip/blip.py +240 -0
- finetune/blip/med.py +955 -0
- finetune/blip/med_config.json +22 -0
- finetune/blip/vit.py +305 -0
- finetune/clean_captions_and_tags.py +184 -0
- finetune/hypernetwork_nai.py +96 -0
- finetune/make_captions.py +162 -0
- finetune/make_captions_by_git.py +145 -0
- finetune/merge_captions_to_metadata.py +67 -0
- finetune/merge_dd_tags_to_metadata.py +62 -0
- finetune/prepare_buckets_latents.py +261 -0
- finetune/tag_images_by_wd14_tagger.py +200 -0
- gen_img_diffusers.py +234 -55
- library/model_util.py +5 -1
- library/train_util.py +853 -229
- networks/check_lora_weights.py +1 -1
- networks/extract_lora_from_models.py +44 -25
- networks/lora.py +191 -30
- networks/merge_lora.py +11 -5
- networks/resize_lora.py +187 -50
- networks/svd_merge_lora.py +40 -18
- requirements.txt +2 -0
- tools/canny.py +24 -0
- tools/original_control_net.py +320 -0
- train_README-ja.md +936 -0
- train_db.py +47 -45
- train_db_README-ja.md +167 -0
- train_network.py +248 -175
- train_network_README-ja.md +269 -0
- train_network_opt.py +324 -373
- train_textual_inversion.py +72 -58
- train_ti_README-ja.md +105 -0
.gitattributes ADDED
@@ -0,0 +1 @@
bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
.github/workflows/typos.yml ADDED
@@ -0,0 +1,21 @@
---
# yamllint disable rule:line-length
name: Typos

on:  # yamllint disable-line rule:truthy
  push:
  pull_request:
    types:
      - opened
      - synchronize
      - reopened

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - name: typos-action
        uses: crate-ci/[email protected]
.gitignore ADDED
@@ -0,0 +1,7 @@
logs
__pycache__
wd14_tagger_model
venv
*.egg-info
build
.vscode
LICENSE.md ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

[Sections 1-9 of the standard Apache License 2.0 text: Definitions; Grant of Copyright License; Grant of Patent License; Redistribution; Submission of Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability; Accepting Warranty or Additional Liability.]

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

Copyright [2022] [kohya-ss]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README-ja.md ADDED
@@ -0,0 +1,147 @@
## About this repository

This repository contains training, image generation and other utility scripts for Stable Diffusion.

[README in English](./README.md) ← update information is posted there

[bmaltais's repository](https://github.com/bmaltais/kohya_ss) provides features that make these scripts easier to use (a GUI, PowerShell scripts and so on; in English). Thanks to bmaltais.

The following scripts are included:

* DreamBooth training, supporting both U-Net and Text Encoder
* Fine-tuning, likewise supporting U-Net and Text Encoder
* Image generation
* Model conversion (between Stable Diffusion ckpt/safetensors and Diffusers)

## Usage

Documentation is available in this repository and on note.com (everything may eventually be moved here):

* [Training, common topics](./train_README-ja.md): data preparation, options, etc.
* [Dataset configuration](./config_README-ja.md)
* [DreamBooth training](./train_db_README-ja.md)
* [Fine-tuning guide](./fine_tune_README_ja.md)
* [LoRA training](./train_network_README-ja.md)
* [Textual Inversion training](./train_ti_README-ja.md)
* note.com [Image generation script](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [Model conversion script](https://note.com/kohya_ss/n/n374f316fe4ad)

## Required programs on Windows

Python 3.10.6 and Git are required.

- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win

If you use PowerShell, change the security settings as follows so that venv can be used
(note that this allows script execution in general, not just for venv):

- Open PowerShell as administrator.
- Enter "Set-ExecutionPolicy Unrestricted" and answer Y.
- Close the administrator PowerShell window.

## Installation on Windows

The example below installs PyTorch 1.12.1 for CUDA 11.6. Adjust the commands accordingly if you use CUDA 11.3 or PyTorch 1.13.

(If the `python -m venv ...` line only prints "python", change `python` to `py`, e.g. `py -m venv ...`.)

Open a regular (non-administrator) PowerShell window and run the following in order:

```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->

In the command prompt, use the following instead:

```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

(Note: ``python -m venv venv`` appears to be safer than ``python -m venv --system-site-packages venv``, so the instructions were changed. With the latter, various problems occur when packages are installed in the global Python.)

Answer the accelerate config questions as follows (if you train with bf16, answer bf16 to the last question):

Note: from accelerate 0.15.0, pressing the cursor keys to make a selection crashes in a Japanese environment. Use the number keys 0, 1, 2, ... to select instead.

```txt
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
```

Note: in some cases the error ``ValueError: fp16 mixed precision requires a GPU`` may occur. If so, answer "0" to the sixth question
(``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``). (The GPU with id `0` will be used.)

### PyTorch and xformers versions

Training may fail with other versions. Unless you have a specific reason, use the specified versions.

## Upgrading

When a new release is available, you can update with the following commands:

```powershell
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --use-pep517 --upgrade -r requirements.txt
```

Once the commands succeed, the new version is ready to use.

## Acknowledgements

The LoRA implementation is based on [cloneofsimo's repository](https://github.com/cloneofsimo/lora). Thank you.

The expansion to Conv2d 3x3 was first released by [cloneofsimo](https://github.com/cloneofsimo/lora), and KohakuBlueleaf demonstrated its effectiveness in [LoCon](https://github.com/KohakuBlueleaf/LoCon). Many thanks to KohakuBlueleaf.

## License

The scripts are licensed under ASL 2.0 (as are those derived from Diffusers and cloneofsimo's repository), but portions use code under other licenses:

[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
README.md ADDED
@@ -0,0 +1,230 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

[__Change History__](#change-history) is moved to the bottom of the page.

[日本語版README](./README-ja.md)

For easier use (GUI and PowerShell scripts etc.), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!

This repository contains the scripts for:

* DreamBooth training, including U-Net and Text Encoder
* Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training
* Textual Inversion training
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffusion ckpt/safetensors and Diffusers)

__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for the great work!

## About requirements.txt

The file does not contain requirements for PyTorch, because the required version depends on your environment. Please install PyTorch first (see the installation guide below).

The scripts are tested with PyTorch 1.12.1 and 1.13.0, and Diffusers 0.10.2.

## Links to how-to-use documents

All documents are currently in Japanese.

* [Training guide - common](./train_README-ja.md): data preparation, options, etc.
* [Dataset config](./config_README-ja.md)
* [DreamBooth training guide](./train_db_README-ja.md)
* [Step by step fine-tuning guide](./fine_tune_README_ja.md)
* [Training LoRA](./train_network_README-ja.md)
* [Training Textual Inversion](./train_ti_README-ja.md)
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)

## Windows Required Dependencies

Python 3.10.6 and Git:

- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win

Give PowerShell unrestricted script access so that venv can work:

- Open an administrator PowerShell window
- Type `Set-ExecutionPolicy Unrestricted` and answer A
- Close the admin PowerShell window

## Windows Installation

Open a regular PowerShell terminal and type the following inside:

```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

Update: ``python -m venv venv`` seems to be safer than ``python -m venv --system-site-packages venv`` (some users have packages installed in their global Python).

Answers to accelerate config:

```txt
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
```

Note: some users report that ``ValueError: fp16 mixed precision requires a GPU`` occurs during training. In this case, answer `0` to the 6th question:
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``

(The single GPU with id `0` will be used.)

### About PyTorch and xformers

Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified versions.

## Upgrade

When a new release comes out, you can upgrade your repo with the following commands:

```powershell
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --use-pep517 --upgrade -r requirements.txt
```

Once the commands have completed successfully, you should be ready to use the new version.

## Credits

The implementation of LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for the great work!

The LoRA expansion to Conv2d 3x3 was initially released by cloneofsimo, and its effectiveness was demonstrated in [LoCon](https://github.com/KohakuBlueleaf/LoCon) by KohakuBlueleaf. Thank you so much, KohakuBlueleaf!

## License

The majority of the scripts are licensed under ASL 2.0 (including code from Diffusers, cloneofsimo's repository and LoCon); however, portions of the project are available under separate license terms:

[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause

## Change History

- 11 Mar. 2023:
  - Fix `svd_merge_lora.py` causing an error about the device.

- 10 Mar. 2023: release v0.5.1
  - Fix so that the LoRA modules in the model are the same as before v0.5.0 when Conv2d-3x3 is disabled (no `conv_dim` arg, the default).
    - Conv2d layers with kernel size 1x1 in ResNet modules were accidentally included in v0.5.0.
    - Models trained with v0.5.0 will still work with the Web UI's built-in LoRA and the Additional Networks extension.
  - Fix an issue where the dim (rank) of a LoRA module was limited to the in/out dimensions of the target Linear/Conv2d (in case the dim > 320).
  - `resize_lora.py` now has a `dynamic resizing` feature, which means each LoRA module can have a different rank (dim). Thanks to mgz-dev for this great work!
    - The appropriate rank is selected based on the complexity of each module, with an algorithm specified in the command line arguments. For details: https://github.com/kohya-ss/sd-scripts/pull/243
  - Multi-GPU training is finally supported in `train_network.py`. Thanks to ddPn08 for solving this long-running issue!
  - Datasets using the fine-tuning method (with a metadata json) now work without images if `.npz` files exist. Thanks to rvhfxb!
  - `train_network.py` now works even if the current directory is not the directory the script is in. Thanks to mio2333!
  - Fix `extract_lora_from_models.py` and `svd_merge_lora.py` not working with higher ranks (>320).

- 9 Mar. 2023: release v0.5.0
  - There may be problems due to major changes. If you cannot revert to the previous version when problems occur, please do not update for a while.
  - Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
  - `train_network.py` supports LoRA for Conv2d-3x3 (extended to Conv2d layers with a kernel size other than 1x1).
    - Same as the current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
    - LoCon will be enhanced in the future. Compatibility with future versions is not guaranteed.
    - Specify the `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"`
    - [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in the Stable Diffusion web UI.
    - __The Stable Diffusion web UI's built-in LoRA does not support 'LoRA for Conv2d-3x3' at the moment. Consider carefully whether or not to use it.__
  - Merging/extracting scripts also support LoRA for Conv2d-3x3.
  - Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260
  - An empty caption no longer causes an error, issue https://github.com/kohya-ss/sd-scripts/issues/258
  - Fix sample generation crashing in Textual Inversion training when using templates, or when height/width is not divisible by 8.
  - Update documents (Japanese only).

- Sample image generation:
  A prompt file might look like this, for example:

  ```
  # prompt 1
  masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28

  # prompt 2
  masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
  ```

  Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used:

  * `--n` Negative prompt up to the next option.
  * `--w` Specifies the width of the generated image.
  * `--h` Specifies the height of the generated image.
  * `--d` Specifies the seed of the generated image.
  * `--l` Specifies the CFG scale of the generated image.
  * `--s` Specifies the number of steps in the generation.

  Prompt weighting such as `( )` and `[ ]` does not work. (A short parsing sketch follows after this file's listing.)

Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
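The prompt-file format above is plain text with inline options. As a rough illustration only (this is not code from the repository; the function name `parse_prompt_line` and the returned keys are invented for the example), one such line could be split into its prompt and options like this:

```python
# Minimal sketch: parse one line of the sample-generation prompt format
# described above. Names and key layout are hypothetical.
import re


def parse_prompt_line(line: str) -> dict:
    """Split a prompt line into the positive prompt and its --x options."""
    line = line.strip()
    if not line or line.startswith("#"):   # blank lines and comments carry no prompt
        return {}
    # options look like " --n value", " --w 768", ... ; split on them
    parts = re.split(r"\s--([a-z])\s+", " " + line)
    result = {"prompt": parts[0].strip()}
    keys = {"n": "negative_prompt", "w": "width", "h": "height",
            "d": "seed", "l": "cfg_scale", "s": "steps"}
    # re.split alternates: [prompt, key, value, key, value, ...]
    for key, value in zip(parts[1::2], parts[2::2]):
        name = keys.get(key, key)
        if name == "negative_prompt":
            result[name] = value.strip()
        elif name == "cfg_scale":
            result[name] = float(value)
        else:
            result[name] = int(value)
    return result


if __name__ == "__main__":
    sample = ("masterpiece, best quality, 1girl, in white shirts, upper body, "
              "looking at viewer, simple background --n low quality, worst quality "
              "--w 768 --h 768 --d 1 --l 7.5 --s 28")
    print(parse_prompt_line(sample))
```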
append_module.py
CHANGED
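The largest change in this file is per-block control of LoRA training: the patched `prepare_optimizer_params` in the diff below builds one optimizer parameter group per LoRA block name and lets a learning-rate dictionary (`lr_dic`) override individual groups. As a minimal sketch of that mechanism only (not code from this commit; the block names, modules and rates are invented for illustration), per-block learning rates boil down to ordinary PyTorch optimizer parameter groups:

```python
# Minimal sketch of per-block learning rates via optimizer parameter groups.
# Block names, modules and rates below are illustrative only.
import torch
import torch.nn as nn

blocks = {
    "lora_te_text_model_encoder_layers_0_": nn.Linear(16, 16),
    "lora_unet_down_blocks_0_": nn.Linear(16, 16),
    "lora_unet_up_blocks_0_": nn.Linear(16, 16),
}
lr_dic = {"lora_unet_up_blocks_0_": 5e-4}   # override for one block
default_lr = 1e-4

# one parameter group per block, with its own learning rate
param_groups = [
    {"params": list(module.parameters()), "lr": lr_dic.get(name, default_lr)}
    for name, module in blocks.items()
]

optimizer = torch.optim.AdamW(param_groups)
for name, group in zip(blocks, optimizer.param_groups):
    print(name, group["lr"])
```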
@@ -2,7 +2,19 @@ import argparse
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from accelerate import Accelerator
|
7 |
from torch.autograd.function import Function
|
8 |
import glob
|
@@ -28,6 +40,7 @@ import safetensors.torch
|
|
28 |
|
29 |
import library.model_util as model_util
|
30 |
import library.train_util as train_util
|
|
|
31 |
|
32 |
#============================================================================================================
|
33 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
@@ -115,6 +128,124 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
|
|
115 |
return area_size_resos_list, area_size_list
|
116 |
|
117 |
#============================================================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
#train_util 内より
|
119 |
#============================================================================================================
|
120 |
class BucketManager_append(train_util.BucketManager):
|
@@ -179,7 +310,7 @@ class BucketManager_append(train_util.BucketManager):
|
|
179 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
180 |
_min_error = 1000.
|
181 |
_min_id = bucket_size_id
|
182 |
-
for now_size_id in
|
183 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
184 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
185 |
ar_error = np.abs(ar_errors).min()
|
@@ -253,13 +384,13 @@ class BucketManager_append(train_util.BucketManager):
|
|
253 |
return reso, resized_size, ar_error
|
254 |
|
255 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
256 |
-
def __init__(self,
|
257 |
print("use append DreamBoothDataset")
|
258 |
self.min_resolution = min_resolution
|
259 |
self.area_step = area_step
|
260 |
-
super().__init__(
|
261 |
-
|
262 |
-
|
263 |
def make_buckets(self):
|
264 |
'''
|
265 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
@@ -352,40 +483,50 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
|
|
352 |
self.shuffle_buckets()
|
353 |
self._length = len(self.buckets_indices)
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
dots.append(".npz")
|
369 |
-
for ext in dots:
|
370 |
-
if base == '*':
|
371 |
-
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
372 |
-
else:
|
373 |
-
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
374 |
-
return img_paths
|
375 |
-
|
376 |
#============================================================================================================
|
377 |
#networks.lora
|
378 |
#============================================================================================================
|
379 |
-
from networks.lora import LoRANetwork
|
380 |
-
def replace_prepare_optimizer_params(networks):
|
381 |
-
def prepare_optimizer_params(self, text_encoder_lr, unet_lr,
|
382 |
-
|
383 |
def enumerate_params(loras, lora_name=None):
|
384 |
params = []
|
385 |
for lora in loras:
|
386 |
if lora_name is not None:
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
else:
|
390 |
params.extend(lora.parameters())
|
391 |
return params
|
@@ -393,6 +534,7 @@ def replace_prepare_optimizer_params(networks):
|
|
393 |
self.requires_grad_(True)
|
394 |
all_params = []
|
395 |
ret_scheduler_lr = []
|
|
|
396 |
|
397 |
if loranames is not None:
|
398 |
textencoder_names = [None]
|
@@ -405,37 +547,181 @@ def replace_prepare_optimizer_params(networks):
|
|
405 |
if self.text_encoder_loras:
|
406 |
for textencoder_name in textencoder_names:
|
407 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
|
|
408 |
if text_encoder_lr is not None:
|
409 |
param_data['lr'] = text_encoder_lr
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
all_params.append(param_data)
|
413 |
|
414 |
if self.unet_loras:
|
415 |
for unet_name in unet_names:
|
416 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
|
|
|
|
417 |
if unet_lr is not None:
|
418 |
param_data['lr'] = unet_lr
|
419 |
-
|
420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
all_params.append(param_data)
|
422 |
|
423 |
-
return all_params, ret_scheduler_lr
|
424 |
-
|
425 |
-
|
|
|
|
|
426 |
|
427 |
#============================================================================================================
|
428 |
#新規追加
|
429 |
#============================================================================================================
|
430 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
431 |
# for train_network_opt.py
|
432 |
-
parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
433 |
-
parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
|
|
|
|
434 |
parser.add_argument("--split_lora_networks", action="store_true")
|
435 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
|
|
|
|
436 |
parser.add_argument("--min_resolution", type=str, default=None)
|
437 |
parser.add_argument("--area_step", type=int, default=1)
|
438 |
parser.add_argument("--config", type=str, default=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
def create_split_names(split_flag, split_level):
|
441 |
split_names = None
|
@@ -446,14 +732,28 @@ def create_split_names(split_flag, split_level):
|
|
446 |
if split_level==1:
|
447 |
unet_names.append(f"lora_unet_down_blocks_")
|
448 |
unet_names.append(f"lora_unet_up_blocks_")
|
449 |
-
elif split_level==2 or split_level==0:
|
450 |
-
if split_level
|
451 |
text_encoder_names = []
|
452 |
for i in range(12):
|
453 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
454 |
-
|
455 |
-
|
456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
split_names["text_encoder"] = text_encoder_names
|
458 |
split_names["unet"] = unet_names
|
459 |
return split_names
|
@@ -465,7 +765,7 @@ def get_config(parser):
|
|
465 |
import datetime
|
466 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
467 |
args.config = os.path.splitext(args.config)[0]
|
468 |
-
config_path = f"
|
469 |
if os.path.exists(config_path):
|
470 |
print(f"{config_path} から設定を読み込み中...")
|
471 |
margs, rest = parser.parse_known_args()
|
@@ -486,19 +786,41 @@ def get_config(parser):
|
|
486 |
args_type_dic[key] = act.type
|
487 |
#データタイプの確認とargsにkeyの内容を代入していく
|
488 |
for key, v in configs.items():
|
489 |
-
if
|
490 |
-
if
|
491 |
-
|
492 |
-
|
493 |
-
v
|
494 |
-
|
495 |
-
|
496 |
if not type(v) == args_type_dic[key]:
|
497 |
v = args_type_dic[key](v)
|
498 |
-
|
499 |
#最後にデフォから指定が変わってるものを変更する
|
500 |
for key, v in change_def_dic.items():
|
501 |
args_dic[key] = v
|
502 |
else:
|
503 |
print(f"{config_path} が見つかりませんでした")
|
504 |
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
+
from typing import (
|
6 |
+
Dict,
|
7 |
+
List,
|
8 |
+
NamedTuple,
|
9 |
+
Optional,
|
10 |
+
Sequence,
|
11 |
+
Tuple,
|
12 |
+
Union,
|
13 |
+
)
|
14 |
+
from dataclasses import (
|
15 |
+
asdict,
|
16 |
+
dataclass,
|
17 |
+
)
|
18 |
from accelerate import Accelerator
|
19 |
from torch.autograd.function import Function
|
20 |
import glob
|
|
|
40 |
|
41 |
import library.model_util as model_util
|
42 |
import library.train_util as train_util
|
43 |
+
import library.config_util as config_util
|
44 |
|
45 |
#============================================================================================================
|
46 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
|
|
128 |
return area_size_resos_list, area_size_list
|
129 |
|
130 |
#============================================================================================================
|
131 |
+
#config_util 内より
|
132 |
+
#============================================================================================================
|
133 |
+
@dataclass
|
134 |
+
class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
|
135 |
+
min_resolution: Optional[Tuple[int, int]] = None
|
136 |
+
area_step : int = 2
|
137 |
+
|
138 |
+
class ConfigSanitizer(config_util.ConfigSanitizer):
|
139 |
+
#@config_util.curry
|
140 |
+
@staticmethod
|
141 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
142 |
+
config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
|
143 |
+
return tuple(value)
|
144 |
+
|
145 |
+
#@config_util.curry
|
146 |
+
@staticmethod
|
147 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
148 |
+
config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
|
149 |
+
try:
|
150 |
+
config_util.Schema(klass)(value)
|
151 |
+
return (value, value)
|
152 |
+
except:
|
153 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
154 |
+
# datasets schema
|
155 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
156 |
+
"batch_size": int,
|
157 |
+
"bucket_no_upscale": bool,
|
158 |
+
"bucket_reso_steps": int,
|
159 |
+
"enable_bucket": bool,
|
160 |
+
"max_bucket_reso": int,
|
161 |
+
"min_bucket_reso": int,
|
162 |
+
"resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
163 |
+
"min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
164 |
+
"area_step": int,
|
165 |
+
}
|
166 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
|
167 |
+
super().__init__(support_dreambooth, support_finetuning, support_dropout)
|
168 |
+
def _check(self):
|
169 |
+
print(self.db_dataset_schema)
|
170 |
+
|
171 |
+
class BlueprintGenerator(config_util.BlueprintGenerator):
|
172 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
173 |
+
config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
|
174 |
+
super().__init__(sanitizer)
|
175 |
+
|
176 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
|
177 |
+
datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
|
178 |
+
|
179 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
180 |
+
if dataset_blueprint.is_dreambooth:
|
181 |
+
subset_klass = train_util.DreamBoothSubset
|
182 |
+
dataset_klass = DreamBoothDataset
|
183 |
+
else:
|
184 |
+
subset_klass = train_util.FineTuningSubset
|
185 |
+
dataset_klass = train_util.FineTuningDataset
|
186 |
+
|
187 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
188 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
189 |
+
datasets.append(dataset)
|
190 |
+
|
191 |
+
# print info
|
192 |
+
info = ""
|
193 |
+
for i, dataset in enumerate(datasets):
|
194 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
195 |
+
info += config_util.dedent(f"""\
|
196 |
+
[Dataset {i}]
|
197 |
+
batch_size: {dataset.batch_size}
|
198 |
+
resolution: {(dataset.width, dataset.height)}
|
199 |
+
enable_bucket: {dataset.enable_bucket}
|
200 |
+
""")
|
201 |
+
|
202 |
+
if dataset.enable_bucket:
|
203 |
+
info += config_util.indent(config_util.dedent(f"""\
|
204 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
205 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
206 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
207 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
208 |
+
\n"""), " ")
|
209 |
+
else:
|
210 |
+
info += "\n"
|
211 |
+
|
212 |
+
for j, subset in enumerate(dataset.subsets):
|
213 |
+
info += config_util.indent(config_util.dedent(f"""\
|
214 |
+
[Subset {j} of Dataset {i}]
|
215 |
+
image_dir: "{subset.image_dir}"
|
216 |
+
image_count: {subset.img_count}
|
217 |
+
num_repeats: {subset.num_repeats}
|
218 |
+
shuffle_caption: {subset.shuffle_caption}
|
219 |
+
keep_tokens: {subset.keep_tokens}
|
220 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
221 |
+
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
222 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
223 |
+
color_aug: {subset.color_aug}
|
224 |
+
flip_aug: {subset.flip_aug}
|
225 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
226 |
+
random_crop: {subset.random_crop}
|
227 |
+
"""), " ")
|
228 |
+
|
229 |
+
if is_dreambooth:
|
230 |
+
info += config_util.indent(config_util.dedent(f"""\
|
231 |
+
is_reg: {subset.is_reg}
|
232 |
+
class_tokens: {subset.class_tokens}
|
233 |
+
caption_extension: {subset.caption_extension}
|
234 |
+
\n"""), " ")
|
235 |
+
else:
|
236 |
+
info += config_util.indent(config_util.dedent(f"""\
|
237 |
+
metadata_file: {subset.metadata_file}
|
238 |
+
\n"""), " ")
|
239 |
+
|
240 |
+
print(info)
|
241 |
+
|
242 |
+
# make buckets first because it determines the length of dataset
|
243 |
+
for i, dataset in enumerate(datasets):
|
244 |
+
print(f"[Dataset {i}]")
|
245 |
+
dataset.make_buckets()
|
246 |
+
|
247 |
+
return train_util.DatasetGroup(datasets)
|
248 |
+
#============================================================================================================
|
249 |
#train_util 内より
|
250 |
#============================================================================================================
|
251 |
class BucketManager_append(train_util.BucketManager):
|
|
|
310 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
311 |
_min_error = 1000.
|
312 |
_min_id = bucket_size_id
|
313 |
+
for now_size_id in bucket_size_id_list:
|
314 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
315 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
316 |
ar_error = np.abs(ar_errors).min()
|
|
|
384 |
return reso, resized_size, ar_error
|
385 |
|
386 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
387 |
+
def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
|
388 |
print("use append DreamBoothDataset")
|
389 |
self.min_resolution = min_resolution
|
390 |
self.area_step = area_step
|
391 |
+
super().__init__(subsets, batch_size, tokenizer, max_token_length,
|
392 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
|
393 |
+
prior_loss_weight, debug_dataset)
|
394 |
def make_buckets(self):
|
395 |
'''
|
396 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
|
|
483 |
self.shuffle_buckets()
|
484 |
self._length = len(self.buckets_indices)
|
485 |
|
486 |
+
import transformers
|
487 |
+
from torch.optim import Optimizer
|
488 |
+
from diffusers.optimization import SchedulerType
|
489 |
+
from typing import Union
|
490 |
+
def get_scheduler_Adafactor(
|
491 |
+
name: Union[str, SchedulerType],
|
492 |
+
optimizer: Optimizer,
|
493 |
+
scheduler_arg: Dict
|
494 |
+
):
|
495 |
+
if name.startswith("adafactor"):
|
496 |
+
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
497 |
+
print(scheduler_arg)
|
498 |
+
return AdafactorSchedule_append(optimizer, **scheduler_arg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
#============================================================================================================
|
500 |
#networks.lora
|
501 |
#============================================================================================================
|
502 |
+
#from networks.lora import LoRANetwork
|
503 |
+
def replace_prepare_optimizer_params(networks, network_module):
|
504 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
|
505 |
+
|
506 |
def enumerate_params(loras, lora_name=None):
|
507 |
params = []
|
508 |
for lora in loras:
|
509 |
if lora_name is not None:
|
510 |
+
get_param_flag = False
|
511 |
+
if "attentions" in lora_name or "lora_unet_up_blocks_0_resnets_2":
|
512 |
+
lora_names = [lora_name]
|
513 |
+
if "attentions" in lora_name:
|
514 |
+
lora_names.append(lora_name.replace("attentions", "resnets"))
|
515 |
+
elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
|
516 |
+
lora_names.append("lora_unet_up_blocks_0_upsamplers_")
|
517 |
+
elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
|
518 |
+
lora_names.append("lora_unet_up_blocks_1_upsamplers_")
|
519 |
+
elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
|
520 |
+
lora_names.append("lora_unet_up_blocks_2_upsamplers_")
|
521 |
+
|
522 |
+
for _name in lora_names:
|
523 |
+
if _name in lora.lora_name:
|
524 |
+
get_param_flag = True
|
525 |
+
break
|
526 |
+
else:
|
527 |
+
if lora_name in lora.lora_name:
|
528 |
+
get_param_flag = True
|
529 |
+
if get_param_flag: params.extend(lora.parameters())
|
530 |
else:
|
531 |
params.extend(lora.parameters())
|
532 |
return params
|
|
|
534 |
self.requires_grad_(True)
|
535 |
all_params = []
|
536 |
ret_scheduler_lr = []
|
537 |
+
used_names = []
|
538 |
|
539 |
if loranames is not None:
|
540 |
textencoder_names = [None]
|
|
|
547 |
if self.text_encoder_loras:
|
548 |
for textencoder_name in textencoder_names:
|
549 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
550 |
+
used_names.append(textencoder_name)
|
551 |
if text_encoder_lr is not None:
|
552 |
param_data['lr'] = text_encoder_lr
|
553 |
+
if lr_dic is not None:
|
554 |
+
if textencoder_name in lr_dic:
|
555 |
+
param_data['lr'] = lr_dic[textencoder_name]
|
556 |
+
print(f"{textencoder_name} lr: {param_data['lr']}")
|
557 |
+
|
558 |
+
if block_args_dic is not None:
|
559 |
+
if "lora_te_" in block_args_dic:
|
560 |
+
for pname, value in block_args_dic["lora_te_"].items():
|
561 |
+
param_data[pname] = value
|
562 |
+
if textencoder_name in block_args_dic:
|
563 |
+
for pname, value in block_args_dic[textencoder_name].items():
|
564 |
+
param_data[pname] = value
|
565 |
+
|
566 |
+
if text_encoder_lr is not None:
|
567 |
+
ret_scheduler_lr.append(text_encoder_lr)
|
568 |
+
else:
|
569 |
+
ret_scheduler_lr.append(0.)
|
570 |
+
if lr_dic is not None:
|
571 |
+
if textencoder_name in lr_dic:
|
572 |
+
ret_scheduler_lr[-1] = lr_dic[textencoder_name]
|
573 |
all_params.append(param_data)
|
574 |
|
575 |
if self.unet_loras:
|
576 |
for unet_name in unet_names:
|
577 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
578 |
+
if len(param_data["params"])==0: continue
|
579 |
+
used_names.append(unet_name)
|
580 |
if unet_lr is not None:
|
581 |
param_data['lr'] = unet_lr
|
582 |
+
if lr_dic is not None:
|
583 |
+
if unet_name in lr_dic:
|
584 |
+
param_data['lr'] = lr_dic[unet_name]
|
585 |
+
print(f"{unet_name} lr: {param_data['lr']}")
|
586 |
+
|
587 |
+
if block_args_dic is not None:
|
588 |
+
if "lora_unet_" in block_args_dic:
|
589 |
+
for pname, value in block_args_dic["lora_unet_"].items():
|
590 |
+
param_data[pname] = value
|
591 |
+
if unet_name in block_args_dic:
|
592 |
+
for pname, value in block_args_dic[unet_name].items():
|
593 |
+
param_data[pname] = value
|
594 |
+
|
595 |
+
if unet_lr is not None:
|
596 |
+
ret_scheduler_lr.append(unet_lr)
|
597 |
+
else:
|
598 |
+
ret_scheduler_lr.append(0.)
|
599 |
+
if lr_dic is not None:
|
600 |
+
if unet_name in lr_dic:
|
601 |
+
ret_scheduler_lr[-1] = lr_dic[unet_name]
|
602 |
all_params.append(param_data)
|
603 |
|
604 |
+
return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
|
605 |
+
try:
|
606 |
+
network_module.LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
|
607 |
+
except:
|
608 |
+
print("cant't replace prepare_optimizer_params")
|
609 |
|
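The `try`/`except` above monkey-patches `LoRANetwork.prepare_optimizer_params` so that, besides the parameter groups, it also returns the per-group initial learning rates and the block names that actually contributed parameters. A sketch of how a caller might consume the three return values; the wrapper function, argument names and learning rates are assumptions, and `loranames` / `lr_dic` would normally come from `create_split_names` / `create_lr_blocks` defined below:

```python
import torch
import networks.lora as lora_module

def build_optimizer_with_block_lrs(network, loranames=None, lr_dic=None,
                                   text_encoder_lr=5e-5, unet_lr=1e-4):
    # patch the class once; afterwards every LoRANetwork instance returns the 3-tuple
    replace_prepare_optimizer_params(network, lora_module)
    params, scheduler_arg, used_names = network.prepare_optimizer_params(
        text_encoder_lr, unet_lr, loranames=loranames, lr_dic=lr_dic)
    # each entry of params may carry its own "lr", so AdamW's default lr is only a fallback
    optimizer = torch.optim.AdamW(params, lr=unet_lr)
    # scheduler_arg is {"initial_lr": [...]}; it can be forwarded to a scheduler that
    # understands per-group initial LRs (e.g. get_scheduler_Adafactor above, with Adafactor)
    return optimizer, scheduler_arg, used_names
```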
610 |
#============================================================================================================
|
611 |
#新規追加
|
612 |
#============================================================================================================
|
613 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
614 |
# for train_network_opt.py
|
615 |
+
#parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
616 |
+
#parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
617 |
+
parser.add_argument("--use_lookahead", action="store_true")
|
618 |
+
parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
|
619 |
parser.add_argument("--split_lora_networks", action="store_true")
|
620 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
621 |
+
parser.add_argument("--blocks_lr_setting", type=str, default=None)
|
622 |
+
parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
|
623 |
parser.add_argument("--min_resolution", type=str, default=None)
|
624 |
parser.add_argument("--area_step", type=int, default=1)
|
625 |
parser.add_argument("--config", type=str, default=None)
|
626 |
+
parser.add_argument("--not_output_config", action="store_true")
|
627 |
+
|
628 |
+
class MyNetwork_Names:
|
629 |
+
ex_block_weight_dic = {
|
630 |
+
"BASE": ["te"],
|
631 |
+
"IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
|
632 |
+
"IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
|
633 |
+
"IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
|
634 |
+
"IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
|
635 |
+
"MID": ["mid"],
|
636 |
+
"OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
|
637 |
+
"OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
|
638 |
+
"OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
|
639 |
+
"OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
|
640 |
+
}
|
641 |
+
|
642 |
+
blocks_name_dic = { "te": "lora_te_",
|
643 |
+
"unet": "lora_unet_",
|
644 |
+
"mid": "lora_unet_mid_block_",
|
645 |
+
"down": "lora_unet_down_blocks_",
|
646 |
+
"up": "lora_unet_up_blocks_"}
|
647 |
+
for i in range(12):
|
648 |
+
blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
|
649 |
+
for i in range(3):
|
650 |
+
blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
|
651 |
+
blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
|
652 |
+
for i in range(4):
|
653 |
+
for j in range(2):
|
654 |
+
if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
|
655 |
+
blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
|
656 |
+
for j in range(3):
|
657 |
+
if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
|
658 |
+
blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
|
659 |
+
if i<=2:
|
660 |
+
blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
|
661 |
+
blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
|
662 |
+
|
663 |
+
def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
|
664 |
+
ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
|
665 |
+
blocks_name_dic = MyNetwork_Names.blocks_name_dic
|
666 |
+
|
667 |
+
lr_dic = {}
|
668 |
+
if lr_setting_str==None or lr_setting_str=="":
|
669 |
+
pass
|
670 |
+
else:
|
671 |
+
lr_settings = lr_setting_str.replace(" ", "").split(",")
|
672 |
+
for lr_setting in lr_settings:
|
673 |
+
key, value = lr_setting.split("=")
|
674 |
+
if key in ex_block_weight_dic:
|
675 |
+
keys = ex_block_weight_dic[key]
|
676 |
+
else:
|
677 |
+
keys = [key]
|
678 |
+
for key in keys:
|
679 |
+
if key in blocks_name_dic:
|
680 |
+
new_key = blocks_name_dic[key]
|
681 |
+
lr_dic[new_key] = float(value)
|
682 |
+
if len(lr_dic)==0:
|
683 |
+
lr_dic = None
|
684 |
+
|
685 |
+
args_dic = {}
|
686 |
+
if (block_optim_args is None):
|
687 |
+
block_optim_args = []
|
688 |
+
if (len(block_optim_args)>0):
|
689 |
+
for my_arg in block_optim_args:
|
690 |
+
my_arg = my_arg.replace(" ", "")
|
691 |
+
splits = my_arg.split(":")
|
692 |
+
b_name = splits[0]
|
693 |
+
|
694 |
+
key, _value = splits[1].split("=")
|
695 |
+
value_type = float
|
696 |
+
if len(splits)==3:
|
697 |
+
if _value=="str":
|
698 |
+
value_type = str
|
699 |
+
elif _value=="int":
|
700 |
+
value_type = int
|
701 |
+
_value = splits[2]
|
702 |
+
if _value=="true" or _value=="false":
|
703 |
+
value_type = lambda v: v == "true"  # convert "true"/"false" into a real boolean
|
704 |
+
if "," in _value:
|
705 |
+
_value = _value.split(",")
|
706 |
+
for i in range(len(_value)):
|
707 |
+
_value[i] = value_type(_value[i])
|
708 |
+
value=tuple(_value)
|
709 |
+
else:
|
710 |
+
value = value_type(_value)
|
711 |
+
|
712 |
+
if b_name in ex_block_weight_dic:
|
713 |
+
b_names = ex_block_weight_dic[b_name]
|
714 |
+
else:
|
715 |
+
b_names = [b_name]
|
716 |
+
for b_name in b_names:
|
717 |
+
new_b_name = blocks_name_dic[b_name]
|
718 |
+
if not new_b_name in args_dic:
|
719 |
+
args_dic[new_b_name] = {}
|
720 |
+
args_dic[new_b_name][key] = value
|
721 |
+
|
722 |
+
if len(args_dic)==0:
|
723 |
+
args_dic = None
|
724 |
+
return lr_dic, args_dic
|
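A hedged example of what `create_lr_blocks` produces for a typical `--blocks_lr_setting` / `--block_optim_args` pair; the concrete numbers are made up, and the resulting prefixes follow `MyNetwork_Names.blocks_name_dic` above:

```python
lr_dic, args_dic = create_lr_blocks(
    lr_setting_str="BASE=5e-5,IN01=1e-4,MID=2e-4",
    block_optim_args=["MID:weight_decay=0.01", "OUT03:betas=0.9,0.99"],
)
# lr_dic maps LoRA-name prefixes to learning rates, roughly:
#   {"lora_te_": 5e-05,
#    "lora_unet_down_blocks_0_attentions_0_": 1e-04,
#    "lora_unet_down_blocks_0_resnets_0": 1e-04,
#    "lora_unet_mid_block_": 2e-04}
# args_dic maps prefixes to extra per-group optimizer kwargs, roughly:
#   {"lora_unet_mid_block_": {"weight_decay": 0.01},
#    "lora_unet_up_blocks_1_attentions_0_": {"betas": (0.9, 0.99)},
#    "lora_unet_up_blocks_1_resnets_0": {"betas": (0.9, 0.99)}}
```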
725 |
|
726 |
def create_split_names(split_flag, split_level):
|
727 |
split_names = None
|
|
|
732 |
if split_level==1:
|
733 |
unet_names.append(f"lora_unet_down_blocks_")
|
734 |
unet_names.append(f"lora_unet_up_blocks_")
|
735 |
+
elif split_level==2 or split_level==0 or split_level==4:
|
736 |
+
if split_level>=2:
|
737 |
text_encoder_names = []
|
738 |
for i in range(12):
|
739 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
740 |
+
|
741 |
+
if split_level<=2:
|
742 |
+
for i in range(3):
|
743 |
+
unet_names.append(f"lora_unet_down_blocks_{i}")
|
744 |
+
unet_names.append(f"lora_unet_up_blocks_{i+1}")
|
745 |
+
|
746 |
+
if split_level>=3:
|
747 |
+
for i in range(4):
|
748 |
+
for j in range(2):
|
749 |
+
if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
|
750 |
+
if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
|
751 |
+
for j in range(3):
|
752 |
+
if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
|
753 |
+
if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
|
754 |
+
if i<=2:
|
755 |
+
unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
|
756 |
+
|
757 |
split_names["text_encoder"] = text_encoder_names
|
758 |
split_names["unet"] = unet_names
|
759 |
return split_names
|
|
|
765 |
import datetime
|
766 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
767 |
args.config = os.path.splitext(args.config)[0]
|
768 |
+
config_path = f"{args.config}.yaml"
|
769 |
if os.path.exists(config_path):
|
770 |
print(f"{config_path} から設定を読み込み中...")
|
771 |
margs, rest = parser.parse_known_args()
|
|
|
786 |
args_type_dic[key] = act.type
|
787 |
#データタイプの確認とargsにkeyの内容を代入していく
|
788 |
for key, v in configs.items():
|
789 |
+
if v is not None:
|
790 |
+
if key in args_dic:
|
791 |
+
if args_dic[key] is not None:
|
792 |
+
new_type = type(args_dic[key])
|
793 |
+
if (not type(v) == new_type) and (not new_type==list):
|
794 |
+
v = new_type(v)
|
795 |
+
else:
|
796 |
if not type(v) == args_type_dic[key]:
|
797 |
v = args_type_dic[key](v)
|
798 |
+
args_dic[key] = v
|
799 |
#最後にデフォから指定が変わってるものを変更する
|
800 |
for key, v in change_def_dic.items():
|
801 |
args_dic[key] = v
|
802 |
else:
|
803 |
print(f"{config_path} が見つかりませんでした")
|
804 |
return args
|
805 |
+
|
806 |
+
'''
|
807 |
+
class GradientReversalFunction(torch.autograd.Function):
|
808 |
+
@staticmethod
|
809 |
+
def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
810 |
+
ctx.save_for_backward(scale)
|
811 |
+
return input_forward
|
812 |
+
@staticmethod
|
813 |
+
def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
814 |
+
scale, = ctx.saved_tensors
|
815 |
+
return scale * -grad_backward, None
|
816 |
+
|
817 |
+
class GradientReversal(torch.nn.Module):
|
818 |
+
def __init__(self, scale: float):
|
819 |
+
super(GradientReversal, self).__init__()
|
820 |
+
self.scale = torch.tensor(scale)
|
821 |
+
def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
|
822 |
+
if flag:
|
823 |
+
return x
|
824 |
+
else:
|
825 |
+
return GradientReversalFunction.apply(x, self.scale)
|
826 |
+
'''
|
bitsandbytes_windows/cextension.py
ADDED
@@ -0,0 +1,54 @@
|
|
1 |
+
import ctypes as ct
|
2 |
+
from pathlib import Path
|
3 |
+
from warnings import warn
|
4 |
+
|
5 |
+
from .cuda_setup.main import evaluate_cuda_setup
|
6 |
+
|
7 |
+
|
8 |
+
class CUDALibrary_Singleton(object):
|
9 |
+
_instance = None
|
10 |
+
|
11 |
+
def __init__(self):
|
12 |
+
raise RuntimeError("Call get_instance() instead")
|
13 |
+
|
14 |
+
def initialize(self):
|
15 |
+
binary_name = evaluate_cuda_setup()
|
16 |
+
package_dir = Path(__file__).parent
|
17 |
+
binary_path = package_dir / binary_name
|
18 |
+
|
19 |
+
if not binary_path.exists():
|
20 |
+
print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
|
21 |
+
legacy_binary_name = "libbitsandbytes.so"
|
22 |
+
print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
|
23 |
+
binary_path = package_dir / legacy_binary_name
|
24 |
+
if not binary_path.exists():
|
25 |
+
print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
|
26 |
+
print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
|
27 |
+
raise Exception('CUDA SETUP: Setup Failed!')
|
28 |
+
# self.lib = ct.cdll.LoadLibrary(binary_path)
|
29 |
+
self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
|
30 |
+
else:
|
31 |
+
print(f"CUDA SETUP: Loading binary {binary_path}...")
|
32 |
+
# self.lib = ct.cdll.LoadLibrary(binary_path)
|
33 |
+
self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def get_instance(cls):
|
37 |
+
if cls._instance is None:
|
38 |
+
cls._instance = cls.__new__(cls)
|
39 |
+
cls._instance.initialize()
|
40 |
+
return cls._instance
|
41 |
+
|
42 |
+
|
43 |
+
lib = CUDALibrary_Singleton.get_instance().lib
|
44 |
+
try:
|
45 |
+
lib.cadam32bit_g32
|
46 |
+
lib.get_context.restype = ct.c_void_p
|
47 |
+
lib.get_cusparse.restype = ct.c_void_p
|
48 |
+
COMPILED_WITH_CUDA = True
|
49 |
+
except AttributeError:
|
50 |
+
warn(
|
51 |
+
"The installed version of bitsandbytes was compiled without GPU support. "
|
52 |
+
"8-bit optimizers and GPU quantization are unavailable."
|
53 |
+
)
|
54 |
+
COMPILED_WITH_CUDA = False
|
bitsandbytes_windows/libbitsandbytes_cpu.dll
ADDED
Binary file (76.3 kB).
|
|
bitsandbytes_windows/libbitsandbytes_cuda116.dll
ADDED
@@ -0,0 +1,3 @@
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:88f7bd2916ca3effc43f88492f1e1b9088d13cb5be3b4a3a4aede6aa3bf8d412
|
3 |
+
size 4724224
|
bitsandbytes_windows/main.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
1 |
+
"""
|
2 |
+
extract factors the build is dependent on:
|
3 |
+
[X] compute capability
|
4 |
+
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
5 |
+
- CUDA version
|
6 |
+
- Software:
|
7 |
+
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
|
8 |
+
- CuBLAS-LT: full-build 8-bit optimizer
|
9 |
+
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
|
10 |
+
|
11 |
+
evaluation:
|
12 |
+
- if paths faulty, return meaningful error
|
13 |
+
- else:
|
14 |
+
- determine CUDA version
|
15 |
+
- determine capabilities
|
16 |
+
- based on that set the default path
|
17 |
+
"""
|
18 |
+
|
19 |
+
import ctypes
|
20 |
+
|
21 |
+
from .paths import determine_cuda_runtime_lib_path
|
22 |
+
|
23 |
+
|
24 |
+
def check_cuda_result(cuda, result_val):
|
25 |
+
# 3. Check for CUDA errors
|
26 |
+
if result_val != 0:
|
27 |
+
error_str = ctypes.c_char_p()
|
28 |
+
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
29 |
+
print(f"CUDA exception! Error code: {error_str.value.decode()}")
|
30 |
+
|
31 |
+
def get_cuda_version(cuda, cudart_path):
|
32 |
+
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
33 |
+
try:
|
34 |
+
cudart = ctypes.CDLL(cudart_path)
|
35 |
+
except OSError:
|
36 |
+
# TODO: shouldn't we error or at least warn here?
|
37 |
+
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
|
38 |
+
return None
|
39 |
+
|
40 |
+
version = ctypes.c_int()
|
41 |
+
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
|
42 |
+
version = int(version.value)
|
43 |
+
major = version//1000
|
44 |
+
minor = (version-(major*1000))//10
|
45 |
+
|
46 |
+
if major < 11:
|
47 |
+
print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')
|
48 |
+
|
49 |
+
return f'{major}{minor}'
|
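A quick worked example of the arithmetic above: `cudaRuntimeGetVersion()` reports an integer such as 11080 for CUDA 11.8, so the function returns the string "118" that is later spliced into the binary name.

```python
version = 11080                            # as reported for CUDA 11.8
major = version // 1000                    # -> 11
minor = (version - (major * 1000)) // 10   # -> 8
assert f"{major}{minor}" == "118"
```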
50 |
+
|
51 |
+
|
52 |
+
def get_cuda_lib_handle():
|
53 |
+
# 1. find libcuda.so library (GPU driver) (/usr/lib)
|
54 |
+
try:
|
55 |
+
cuda = ctypes.CDLL("libcuda.so")
|
56 |
+
except OSError:
|
57 |
+
# TODO: shouldn't we error or at least warn here?
|
58 |
+
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
59 |
+
return None
|
60 |
+
check_cuda_result(cuda, cuda.cuInit(0))
|
61 |
+
|
62 |
+
return cuda
|
63 |
+
|
64 |
+
|
65 |
+
def get_compute_capabilities(cuda):
|
66 |
+
"""
|
67 |
+
1. find libcuda.so library (GPU driver) (/usr/lib)
|
68 |
+
init_device -> init variables -> call function by reference
|
69 |
+
2. call extern C function to determine CC
|
70 |
+
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
71 |
+
3. Check for CUDA errors
|
72 |
+
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
73 |
+
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
74 |
+
"""
|
75 |
+
|
76 |
+
|
77 |
+
nGpus = ctypes.c_int()
|
78 |
+
cc_major = ctypes.c_int()
|
79 |
+
cc_minor = ctypes.c_int()
|
80 |
+
|
81 |
+
device = ctypes.c_int()
|
82 |
+
|
83 |
+
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
84 |
+
ccs = []
|
85 |
+
for i in range(nGpus.value):
|
86 |
+
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
|
87 |
+
ref_major = ctypes.byref(cc_major)
|
88 |
+
ref_minor = ctypes.byref(cc_minor)
|
89 |
+
# 2. call extern C function to determine CC
|
90 |
+
check_cuda_result(
|
91 |
+
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
|
92 |
+
)
|
93 |
+
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
94 |
+
|
95 |
+
return ccs
|
96 |
+
|
97 |
+
|
98 |
+
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
|
99 |
+
def get_compute_capability(cuda):
|
100 |
+
"""
|
101 |
+
Extracts the highest compute capability from all available GPUs, as compute
|
102 |
+
capabilities are downwards compatible. If no GPUs are detected, it returns
|
103 |
+
None.
|
104 |
+
"""
|
105 |
+
ccs = get_compute_capabilities(cuda)
|
106 |
+
if ccs is not None:
|
107 |
+
# TODO: handle different compute capabilities; for now, take the max
|
108 |
+
return ccs[-1]
|
109 |
+
return None
|
110 |
+
|
111 |
+
|
112 |
+
def evaluate_cuda_setup():
|
113 |
+
print('')
|
114 |
+
print('='*35 + 'BUG REPORT' + '='*35)
|
115 |
+
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
116 |
+
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
117 |
+
print('='*80)
|
118 |
+
return "libbitsandbytes_cuda116.dll" # $$$
|
119 |
+
|
120 |
+
binary_name = "libbitsandbytes_cpu.so"
|
121 |
+
#if not torch.cuda.is_available():
|
122 |
+
#print('No GPU detected. Loading CPU library...')
|
123 |
+
#return binary_name
|
124 |
+
|
125 |
+
cudart_path = determine_cuda_runtime_lib_path()
|
126 |
+
if cudart_path is None:
|
127 |
+
print(
|
128 |
+
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
|
129 |
+
)
|
130 |
+
return binary_name
|
131 |
+
|
132 |
+
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
|
133 |
+
cuda = get_cuda_lib_handle()
|
134 |
+
cc = get_compute_capability(cuda)
|
135 |
+
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
|
136 |
+
cuda_version_string = get_cuda_version(cuda, cudart_path)
|
137 |
+
|
138 |
+
|
139 |
+
if cc == '':
|
140 |
+
print(
|
141 |
+
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
|
142 |
+
)
|
143 |
+
return binary_name
|
144 |
+
|
145 |
+
# 7.5 is the minimum CC for cublaslt
|
146 |
+
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
|
147 |
+
|
148 |
+
# TODO:
|
149 |
+
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
|
150 |
+
# (2) Multiple CUDA versions installed
|
151 |
+
|
152 |
+
# we use ls -l instead of nvcc to determine the cuda version
|
153 |
+
# since most installations will have the libcudart.so installed, but not the compiler
|
154 |
+
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
|
155 |
+
|
156 |
+
def get_binary_name():
|
157 |
+
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
|
158 |
+
bin_base_name = "libbitsandbytes_cuda"
|
159 |
+
if has_cublaslt:
|
160 |
+
return f"{bin_base_name}{cuda_version_string}.so"
|
161 |
+
else:
|
162 |
+
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
|
163 |
+
|
164 |
+
binary_name = get_binary_name()
|
165 |
+
|
166 |
+
return binary_name
|
config_README-ja.md
ADDED
@@ -0,0 +1,279 @@
|
|
|
1 |
+
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
|
2 |
+
|
3 |
+
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
|
4 |
+
|
5 |
+
## 概要
|
6 |
+
|
7 |
+
設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
|
8 |
+
|
9 |
+
* 複数のデータセットが設定可能になります
|
10 |
+
* 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
|
11 |
+
* DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
|
12 |
+
* サブセットごとに設定を変更することが可能になります
|
13 |
+
* データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
|
14 |
+
* `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
|
15 |
+
|
16 |
+
設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
|
17 |
+
|
18 |
+
TOML で記述した設定ファイルの例です。
|
19 |
+
|
20 |
+
```toml
|
21 |
+
[general]
|
22 |
+
shuffle_caption = true
|
23 |
+
caption_extension = '.txt'
|
24 |
+
keep_tokens = 1
|
25 |
+
|
26 |
+
# これは DreamBooth 方式のデータセット
|
27 |
+
[[datasets]]
|
28 |
+
resolution = 512
|
29 |
+
batch_size = 4
|
30 |
+
keep_tokens = 2
|
31 |
+
|
32 |
+
[[datasets.subsets]]
|
33 |
+
image_dir = 'C:\hoge'
|
34 |
+
class_tokens = 'hoge girl'
|
35 |
+
# このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
|
36 |
+
|
37 |
+
[[datasets.subsets]]
|
38 |
+
image_dir = 'C:\fuga'
|
39 |
+
class_tokens = 'fuga boy'
|
40 |
+
keep_tokens = 3
|
41 |
+
|
42 |
+
[[datasets.subsets]]
|
43 |
+
is_reg = true
|
44 |
+
image_dir = 'C:\reg'
|
45 |
+
class_tokens = 'human'
|
46 |
+
keep_tokens = 1
|
47 |
+
|
48 |
+
# これは fine tuning 方式のデータセット
|
49 |
+
[[datasets]]
|
50 |
+
resolution = [768, 768]
|
51 |
+
batch_size = 2
|
52 |
+
|
53 |
+
[[datasets.subsets]]
|
54 |
+
image_dir = 'C:\piyo'
|
55 |
+
metadata_file = 'C:\piyo\piyo_md.json'
|
56 |
+
# このサブセットは keep_tokens = 1 (general の値が使われる)
|
57 |
+
```
|
58 |
+
|
59 |
+
この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
|
60 |
+
|
61 |
+
## データセット・サブセットに関する設定
|
62 |
+
|
63 |
+
データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。
|
64 |
+
|
65 |
+
* `[general]`
|
66 |
+
* 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。
|
67 |
+
* データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。
|
68 |
+
* `[[datasets]]`
|
69 |
+
* `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。
|
70 |
+
* サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。
|
71 |
+
* `[[datasets.subsets]]`
|
72 |
+
* `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。
|
73 |
+
|
74 |
+
先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。
|
75 |
+
|
76 |
+
```
|
77 |
+
C:\
|
78 |
+
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
|
79 |
+
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
|
80 |
+
├─ reg -> [[datasets.subsets]] No.3 ┘ |
|
81 |
+
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
|
82 |
+
```
|
83 |
+
|
84 |
+
画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。
|
85 |
+
|
86 |
+
登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。
|
87 |
+
|
88 |
+
加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。
|
89 |
+
|
90 |
+
* DreamBooth 方式専用のオプション
|
91 |
+
* fine tuning 方式専用のオプション
|
92 |
+
* caption dropout の手法が使える場合のオプション
|
93 |
+
|
94 |
+
DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。
|
95 |
+
併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。
|
96 |
+
つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。
|
97 |
+
|
98 |
+
プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。
|
99 |
+
そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。
|
100 |
+
|
101 |
+
以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。
|
102 |
+
|
103 |
+
### 全学習方法で共通のオプション
|
104 |
+
|
105 |
+
学習方法によらずに指定可能なオプションです。
|
106 |
+
|
107 |
+
#### データセット向けオプション
|
108 |
+
|
109 |
+
データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。
|
110 |
+
|
111 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` |
|
112 |
+
| ---- | ---- | ---- | ---- |
|
113 |
+
| `batch_size` | `1` | o | o |
|
114 |
+
| `bucket_no_upscale` | `true` | o | o |
|
115 |
+
| `bucket_reso_steps` | `64` | o | o |
|
116 |
+
| `enable_bucket` | `true` | o | o |
|
117 |
+
| `max_bucket_reso` | `1024` | o | o |
|
118 |
+
| `min_bucket_reso` | `128` | o | o |
|
119 |
+
| `resolution` | `256`, `[512, 512]` | o | o |
|
120 |
+
|
121 |
+
* `batch_size`
|
122 |
+
* コマンドライン引数の `--train_batch_size` と同等です。
|
123 |
+
|
124 |
+
これらの設定はデータセットごとに固定です。
|
125 |
+
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
|
126 |
+
例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。
|
127 |
+
|
128 |
+
#### サブセット向けオプション
|
129 |
+
|
130 |
+
サブセットの設定に関わるオプションです。
|
131 |
+
|
132 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
133 |
+
| ---- | ---- | ---- | ---- | ---- |
|
134 |
+
| `color_aug` | `false` | o | o | o |
|
135 |
+
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
|
136 |
+
| `flip_aug` | `true` | o | o | o |
|
137 |
+
| `keep_tokens` | `2` | o | o | o |
|
138 |
+
| `num_repeats` | `10` | o | o | o |
|
139 |
+
| `random_crop` | `false` | o | o | o |
|
140 |
+
| `shuffle_caption` | `true` | o | o | o |
|
141 |
+
|
142 |
+
* `num_repeats`
|
143 |
+
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
144 |
+
|
145 |
+
### DreamBooth 方式専用のオプション
|
146 |
+
|
147 |
+
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
|
148 |
+
|
149 |
+
#### サブセット向けオプション
|
150 |
+
|
151 |
+
DreamBooth 方式のサブセットの設定に関わるオプションです。
|
152 |
+
|
153 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
154 |
+
| ---- | ---- | ---- | ---- | ---- |
|
155 |
+
| `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
|
156 |
+
| `caption_extension` | `".txt"` | o | o | o |
|
157 |
+
| `class_tokens` | `“sks girl”` | - | - | o |
|
158 |
+
| `is_reg` | `false` | - | - | o |
|
159 |
+
|
160 |
+
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
|
161 |
+
|
162 |
+
* `image_dir`
|
163 |
+
* 画像ディレクトリのパスを指定します。指定必須オプションです。
|
164 |
+
* 画像はディレクトリ直下に置かれている必要があります。
|
165 |
+
* `class_tokens`
|
166 |
+
* クラストークンを設定します。
|
167 |
+
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルが見つからなかった場合にはエラーになります。
|
168 |
+
* `is_reg`
|
169 |
+
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
|
170 |
+
|
171 |
+
### fine tuning 方式専用のオプション
|
172 |
+
|
173 |
+
fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。
|
174 |
+
|
175 |
+
#### サブセット向けオプション
|
176 |
+
|
177 |
+
fine tuning 方式のサブセットの設定に関わるオプションです。
|
178 |
+
|
179 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
180 |
+
| ---- | ---- | ---- | ---- | ---- |
|
181 |
+
| `image_dir` | `‘C:\hoge’` | - | - | o |
|
182 |
+
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) |
|
183 |
+
|
184 |
+
* `image_dir`
|
185 |
+
* 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。
|
186 |
+
* 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。
|
187 |
+
* 画像はディレクトリ直下に置かれている必要があります。
|
188 |
+
* `metadata_file`
|
189 |
+
* サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。
|
190 |
+
* コマンドライン引数の `--in_json` と同等です。
|
191 |
+
* サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。
|
192 |
+
|
193 |
+
### caption dropout の手法が使える場合に指定可能なオプション
|
194 |
+
|
195 |
+
caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。
|
196 |
+
DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。
|
197 |
+
|
198 |
+
#### サブセット向けオプション
|
199 |
+
|
200 |
+
caption dropout が使えるサブセットの設定に関わるオプションです。
|
201 |
+
|
202 |
+
| オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
203 |
+
| ---- | ---- | ---- | ---- |
|
204 |
+
| `caption_dropout_every_n_epochs` | o | o | o |
|
205 |
+
| `caption_dropout_rate` | o | o | o |
|
206 |
+
| `caption_tag_dropout_rate` | o | o | o |
|
207 |
+
|
208 |
+
## 重複したサブセットが存在する時の挙動
|
209 |
+
|
210 |
+
DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。
|
211 |
+
fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。
|
212 |
+
データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。
|
213 |
+
|
214 |
+
一方、異なるデータセットに所属している場合は、重複しているとは見なされません。
|
215 |
+
例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。
|
216 |
+
これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。
|
217 |
+
|
218 |
+
```toml
|
219 |
+
# 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる
|
220 |
+
|
221 |
+
[[datasets]]
|
222 |
+
resolution = 512
|
223 |
+
|
224 |
+
[[datasets.subsets]]
|
225 |
+
image_dir = 'C:\hoge'
|
226 |
+
|
227 |
+
[[datasets]]
|
228 |
+
resolution = 768
|
229 |
+
|
230 |
+
[[datasets.subsets]]
|
231 |
+
image_dir = 'C:\hoge'
|
232 |
+
```
|
233 |
+
|
234 |
+
## コマンドライン引数との併用
|
235 |
+
|
236 |
+
設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
|
237 |
+
|
238 |
+
以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。
|
239 |
+
|
240 |
+
* `--train_data_dir`
|
241 |
+
* `--reg_data_dir`
|
242 |
+
* `--in_json`
|
243 |
+
|
244 |
+
以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。
|
245 |
+
|
246 |
+
| コマンドライン引数のオプション | 優先される設定ファイルのオプション |
|
247 |
+
| ---------------------------------- | ---------------------------------- |
|
248 |
+
| `--bucket_no_upscale` | |
|
249 |
+
| `--bucket_reso_steps` | |
|
250 |
+
| `--caption_dropout_every_n_epochs` | |
|
251 |
+
| `--caption_dropout_rate` | |
|
252 |
+
| `--caption_extension` | |
|
253 |
+
| `--caption_tag_dropout_rate` | |
|
254 |
+
| `--color_aug` | |
|
255 |
+
| `--dataset_repeats` | `num_repeats` |
|
256 |
+
| `--enable_bucket` | |
|
257 |
+
| `--face_crop_aug_range` | |
|
258 |
+
| `--flip_aug` | |
|
259 |
+
| `--keep_tokens` | |
|
260 |
+
| `--min_bucket_reso` | |
|
261 |
+
| `--random_crop` | |
|
262 |
+
| `--resolution` | |
|
263 |
+
| `--shuffle_caption` | |
|
264 |
+
| `--train_batch_size` | `batch_size` |
|
265 |
+
|
266 |
+
## エラーの手引き
|
267 |
+
|
268 |
+
現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。
|
269 |
+
将来的にはこの問題の改善に取り組む予定です。
|
270 |
+
|
271 |
+
次善策として、頻出のエラーとその対処法について載せておきます。
|
272 |
+
正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。
|
273 |
+
|
274 |
+
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。
|
275 |
+
* `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。
|
276 |
+
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
|
277 |
+
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
|
278 |
+
|
279 |
+
|
fine_tune.py
CHANGED
@@ -13,7 +13,11 @@ import diffusers
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def collate_fn(examples):
|
19 |
return examples[0]
|
@@ -30,25 +34,36 @@ def train(args):
|
|
30 |
|
31 |
tokenizer = train_util.load_tokenizer(args)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
if args.debug_dataset:
|
46 |
-
train_util.debug_dataset(
|
47 |
return
|
48 |
-
if len(
|
49 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
50 |
return
|
51 |
|
|
|
|
|
|
|
52 |
# acceleratorを準備する
|
53 |
print("prepare accelerator")
|
54 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -109,7 +124,7 @@ def train(args):
|
|
109 |
vae.requires_grad_(False)
|
110 |
vae.eval()
|
111 |
with torch.no_grad():
|
112 |
-
|
113 |
vae.to("cpu")
|
114 |
if torch.cuda.is_available():
|
115 |
torch.cuda.empty_cache()
|
@@ -149,33 +164,13 @@ def train(args):
|
|
149 |
|
150 |
# 学習に必要なクラスを準備する
|
151 |
print("prepare optimizer, data loader etc.")
|
152 |
-
|
153 |
-
# 8-bit Adamを使う
|
154 |
-
if args.use_8bit_adam:
|
155 |
-
try:
|
156 |
-
import bitsandbytes as bnb
|
157 |
-
except ImportError:
|
158 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
159 |
-
print("use 8-bit Adam optimizer")
|
160 |
-
optimizer_class = bnb.optim.AdamW8bit
|
161 |
-
elif args.use_lion_optimizer:
|
162 |
-
try:
|
163 |
-
import lion_pytorch
|
164 |
-
except ImportError:
|
165 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
166 |
-
print("use Lion optimizer")
|
167 |
-
optimizer_class = lion_pytorch.Lion
|
168 |
-
else:
|
169 |
-
optimizer_class = torch.optim.AdamW
|
170 |
-
|
171 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
172 |
-
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
173 |
|
174 |
# dataloaderを準備する
|
175 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
176 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
177 |
train_dataloader = torch.utils.data.DataLoader(
|
178 |
-
|
179 |
|
180 |
# 学習ステップ数を計算する
|
181 |
if args.max_train_epochs is not None:
|
@@ -183,8 +178,9 @@ def train(args):
|
|
183 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
-
lr_scheduler =
|
187 |
-
|
|
|
188 |
|
189 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
190 |
if args.full_fp16:
|
@@ -218,7 +214,7 @@ def train(args):
|
|
218 |
# 学習する
|
219 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
220 |
print("running training / 学習開始")
|
221 |
-
print(f" num examples / サンプル数: {
|
222 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
223 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
224 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -237,7 +233,7 @@ def train(args):
|
|
237 |
|
238 |
for epoch in range(num_train_epochs):
|
239 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
240 |
-
|
241 |
|
242 |
for m in training_models:
|
243 |
m.train()
|
@@ -286,11 +282,11 @@ def train(args):
|
|
286 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
287 |
|
288 |
accelerator.backward(loss)
|
289 |
-
if accelerator.sync_gradients:
|
290 |
params_to_clip = []
|
291 |
for m in training_models:
|
292 |
params_to_clip.extend(m.parameters())
|
293 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
294 |
|
295 |
optimizer.step()
|
296 |
lr_scheduler.step()
|
@@ -301,11 +297,16 @@ def train(args):
|
|
301 |
progress_bar.update(1)
|
302 |
global_step += 1
|
303 |
|
|
|
|
|
304 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
305 |
if args.logging_dir is not None:
|
306 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
|
|
309 |
loss_total += current_loss
|
310 |
avr_loss = loss_total / (step+1)
|
311 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
@@ -315,7 +316,7 @@ def train(args):
|
|
315 |
break
|
316 |
|
317 |
if args.logging_dir is not None:
|
318 |
-
logs = {"
|
319 |
accelerator.log(logs, step=epoch+1)
|
320 |
|
321 |
accelerator.wait_for_everyone()
|
@@ -325,6 +326,8 @@ def train(args):
|
|
325 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
326 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
327 |
|
|
|
|
|
328 |
is_main_process = accelerator.is_main_process
|
329 |
if is_main_process:
|
330 |
unet = unwrap_model(unet)
|
@@ -351,6 +354,8 @@ if __name__ == '__main__':
|
|
351 |
train_util.add_dataset_arguments(parser, False, True, True)
|
352 |
train_util.add_training_arguments(parser, False)
|
353 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
354 |
|
355 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
356 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
+
import library.config_util as config_util
|
17 |
+
from library.config_util import (
|
18 |
+
ConfigSanitizer,
|
19 |
+
BlueprintGenerator,
|
20 |
+
)
|
21 |
|
22 |
def collate_fn(examples):
|
23 |
return examples[0]
|
|
|
34 |
|
35 |
tokenizer = train_util.load_tokenizer(args)
|
36 |
|
37 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
38 |
+
if args.dataset_config is not None:
|
39 |
+
print(f"Load dataset config from {args.dataset_config}")
|
40 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
41 |
+
ignored = ["train_data_dir", "in_json"]
|
42 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
43 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
44 |
+
else:
|
45 |
+
user_config = {
|
46 |
+
"datasets": [{
|
47 |
+
"subsets": [{
|
48 |
+
"image_dir": args.train_data_dir,
|
49 |
+
"metadata_file": args.in_json,
|
50 |
+
}]
|
51 |
+
}]
|
52 |
+
}
|
53 |
+
|
54 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
if args.debug_dataset:
|
58 |
+
train_util.debug_dataset(train_dataset_group)
|
59 |
return
|
60 |
+
if len(train_dataset_group) == 0:
|
61 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
62 |
return
|
63 |
|
64 |
+
if cache_latents:
|
65 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
+
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
124 |
vae.requires_grad_(False)
|
125 |
vae.eval()
|
126 |
with torch.no_grad():
|
127 |
+
train_dataset_group.cache_latents(vae)
|
128 |
vae.to("cpu")
|
129 |
if torch.cuda.is_available():
|
130 |
torch.cuda.empty_cache()
|
|
|
164 |
|
165 |
# 学習に必要なクラスを準備する
|
166 |
print("prepare optimizer, data loader etc.")
|
167 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
|
|
168 |
|
169 |
# dataloaderを準備する
|
170 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
171 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
172 |
train_dataloader = torch.utils.data.DataLoader(
|
173 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
174 |
|
175 |
# 学習ステップ数を計算する
|
176 |
if args.max_train_epochs is not None:
|
|
|
178 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
179 |
|
180 |
# lr schedulerを用意する
|
181 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
182 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
183 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
184 |
|
185 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
186 |
if args.full_fp16:
|
|
|
214 |
# 学習する
|
215 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
216 |
print("running training / 学習開始")
|
217 |
+
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
218 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
219 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
220 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
233 |
|
234 |
for epoch in range(num_train_epochs):
|
235 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
236 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
237 |
|
238 |
for m in training_models:
|
239 |
m.train()
|
|
|
282 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
283 |
|
284 |
accelerator.backward(loss)
|
285 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
286 |
params_to_clip = []
|
287 |
for m in training_models:
|
288 |
params_to_clip.extend(m.parameters())
|
289 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
300 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
301 |
+
|
302 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
303 |
if args.logging_dir is not None:
|
304 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
305 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
306 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
309 |
+
# TODO moving averageにする
|
310 |
loss_total += current_loss
|
311 |
avr_loss = loss_total / (step+1)
|
312 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
316 |
break
|
317 |
|
318 |
if args.logging_dir is not None:
|
319 |
+
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
320 |
accelerator.log(logs, step=epoch+1)
|
321 |
|
322 |
accelerator.wait_for_everyone()
|
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
329 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
330 |
+
|
331 |
is_main_process = accelerator.is_main_process
|
332 |
if is_main_process:
|
333 |
unet = unwrap_model(unet)
|
|
|
354 |
train_util.add_dataset_arguments(parser, False, True, True)
|
355 |
train_util.add_training_arguments(parser, False)
|
356 |
train_util.add_sd_saving_arguments(parser)
|
357 |
+
train_util.add_optimizer_arguments(parser)
|
358 |
+
config_util.add_config_arguments(parser)
|
359 |
|
360 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
361 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
fine_tune_README_ja.md
ADDED
@@ -0,0 +1,140 @@
|
|
|
1 |
+
NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)
|
2 |
+
|
3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
4 |
+
|
5 |
+
# 概要
|
6 |
+
|
7 |
+
Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
|
8 |
+
|
9 |
+
* CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
|
10 |
+
* 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
|
11 |
+
* トークン長を75から225に拡張する。
|
12 |
+
* BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
|
13 |
+
* Hypernetworkの学習にも対応する。
|
14 |
+
* Stable Diffusion v2.0(baseおよび768/v)に対応。
|
15 |
+
* VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
|
16 |
+
|
17 |
+
デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
|
18 |
+
|
19 |
+
# 追加機能について
|
20 |
+
|
21 |
+
## CLIPの出力の変更
|
22 |
+
|
23 |
+
プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
|
24 |
+
元のまま、最後の層の出力を用いることも可能です。
|
25 |
+
|
26 |
+
※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
|
27 |
+
|
28 |
+
## 正方形以外の解像度での学習
|
29 |
+
|
30 |
+
Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
|
31 |
+
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
|
32 |
+
|
33 |
+
機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
|
34 |
+
|
35 |
+
## トークン長の75から225への拡張
|
36 |
+
|
37 |
+
Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
|
38 |
+
ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
|
39 |
+
|
40 |
+
※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
|
41 |
+
|
42 |
+
※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
|
43 |
+
|
44 |
+
# 学習の手順
|
45 |
+
|
46 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
47 |
+
|
48 |
+
## データの準備
|
49 |
+
|
50 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
|
51 |
+
|
52 |
+
## 学習の実行
|
53 |
+
たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
|
54 |
+
|
55 |
+
```
|
56 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
57 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
58 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
59 |
+
--output_name=<学習したモデル出力時のファイル名>
|
60 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
61 |
+
--save_model_as=safetensors
|
62 |
+
--learning_rate=5e-6 --max_train_steps=10000
|
63 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
64 |
+
--mixed_precision=fp16
|
65 |
+
```
|
66 |
+
|
67 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
68 |
+
|
69 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
70 |
+
|
71 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
72 |
+
|
73 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
74 |
+
|
75 |
+
学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
|
76 |
+
|
77 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
78 |
+
|
79 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
80 |
+
|
81 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
82 |
+
|
83 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
|
84 |
+
|
85 |
+
### よく使われるオプションについて
|
86 |
+
|
87 |
+
以下の場合にはオプションに関するドキュメントを参照してください。
|
88 |
+
|
89 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
90 |
+
- clip skipを2以上を前提としたモデルを学習する
|
91 |
+
- 75トークンを超えたキャプションで学習する
|
92 |
+
|
93 |
+
### バッチサイズについて
|
94 |
+
|
95 |
+
モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
|
96 |
+
|
97 |
+
### 学習率について
|
98 |
+
|
99 |
+
1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
|
100 |
+
|
101 |
+
### 以前の形式のデータセット指定をした場合のコマンドライン
|
102 |
+
|
103 |
+
解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
|
104 |
+
|
105 |
+
```
|
106 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
107 |
+
--pretrained_model_name_or_path=model.ckpt
|
108 |
+
--in_json meta_lat.json
|
109 |
+
--train_data_dir=train_data
|
110 |
+
--output_dir=fine_tuned
|
111 |
+
--shuffle_caption
|
112 |
+
--train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
|
113 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
114 |
+
--mixed_precision=bf16
|
115 |
+
--save_every_n_epochs=4
|
116 |
+
```
|
117 |
+
|
118 |
+
<!--
|
119 |
+
### 勾配をfp16とした学習(実験的機能)
|
120 |
+
full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
|
121 |
+
|
122 |
+
あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
|
123 |
+
|
124 |
+
メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
|
125 |
+
(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
|
126 |
+
|
127 |
+
PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
|
128 |
+
-->
|
129 |
+
|
130 |
+
# fine tuning特有のその他の主なオプション
|
131 |
+
|
132 |
+
すべてのオプションについては別文書を参照してください。
|
133 |
+
|
134 |
+
## `train_text_encoder`
|
135 |
+
Text Encoderも学習対象とします。メモリ使用量が若干増加します。
|
136 |
+
|
137 |
+
通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
|
138 |
+
|
139 |
+
## `diffusers_xformers`
|
140 |
+
スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
|
finetune/blip/blip.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")

# from models.vit import VisionTransformer, interpolate_pos_embed
# from models.med import BertConfig, BertModel, BertLMHeadModel
from blip.vit import VisionTransformer, interpolate_pos_embed
from blip.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file


class BLIP_Base(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

    def forward(self, image, caption, mode):

        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode=='image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode=='text':
            # return text features
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            return text_output.last_hidden_state

        elif mode=='multimodal':
            # return multimodel features
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

            text.input_ids[:,0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                       )
            return output.last_hidden_state


class BLIP_Decoder(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1

    def forward(self, image, caption):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

        text.input_ids[:,0] = self.tokenizer.bos_token_id

        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:,:self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                           )
        loss_lm = decoder_output.loss

        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:,0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            # nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 do_sample=True,
                                                 top_p=top_p,
                                                 num_return_sequences=1,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=1.1,
                                                 **model_kwargs)
        else:
            # beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        return captions


def blip_decoder(pretrained='', **kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert(len(msg.missing_keys)==0)
    return model


def blip_feature_extractor(pretrained='', **kwargs):
    model = BLIP_Base(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert(len(msg.missing_keys)==0)
    return model


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg

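As a minimal usage sketch of the factories above (not part of this commit): the checkpoint filename, image path, and preprocessing constants below are assumptions chosen for illustration, and the import assumes the finetune/ directory is on sys.path so that the blip package resolves as it does in this file.

# Hedged captioning sketch built on blip_decoder() / BLIP_Decoder.generate() above.
import torch
from PIL import Image
from torchvision import transforms
from blip.blip import blip_decoder  # assumes finetune/ is on sys.path

device = "cuda" if torch.cuda.is_available() else "cpu"
image_size = 384

# Resize + CLIP-style normalization; mean/std values assumed from the upstream BLIP repo.
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

model = blip_decoder(pretrained="model_base_caption_capfilt_large.pth",  # hypothetical local checkpoint
                     image_size=image_size, vit="base",
                     med_config="finetune/blip/med_config.json")
model.eval().to(device)

image = preprocess(Image.open("sample.jpg").convert("RGB")).unsqueeze(0).to(device)  # hypothetical image
with torch.no_grad():
    print(model.generate(image, sample=False, num_beams=3, max_length=30, min_length=10))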
finetune/blip/med.py
ADDED
@@ -0,0 +1,955 @@
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
'''
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
59 |
+
|
60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
61 |
+
# any TensorFlow checkpoint file
|
62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
64 |
+
|
65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
68 |
+
|
69 |
+
self.config = config
|
70 |
+
|
71 |
+
def forward(
|
72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
73 |
+
):
|
74 |
+
if input_ids is not None:
|
75 |
+
input_shape = input_ids.size()
|
76 |
+
else:
|
77 |
+
input_shape = inputs_embeds.size()[:-1]
|
78 |
+
|
79 |
+
seq_length = input_shape[1]
|
80 |
+
|
81 |
+
if position_ids is None:
|
82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
83 |
+
|
84 |
+
if inputs_embeds is None:
|
85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
86 |
+
|
87 |
+
embeddings = inputs_embeds
|
88 |
+
|
89 |
+
if self.position_embedding_type == "absolute":
|
90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
91 |
+
embeddings += position_embeddings
|
92 |
+
embeddings = self.LayerNorm(embeddings)
|
93 |
+
embeddings = self.dropout(embeddings)
|
94 |
+
return embeddings
|
95 |
+
|
96 |
+
|
97 |
+
class BertSelfAttention(nn.Module):
|
98 |
+
def __init__(self, config, is_cross_attention):
|
99 |
+
super().__init__()
|
100 |
+
self.config = config
|
101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
102 |
+
raise ValueError(
|
103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
105 |
+
)
|
106 |
+
|
107 |
+
self.num_attention_heads = config.num_attention_heads
|
108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
110 |
+
|
111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
112 |
+
if is_cross_attention:
|
113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
115 |
+
else:
|
116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
118 |
+
|
119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
124 |
+
self.save_attention = False
|
125 |
+
|
126 |
+
def save_attn_gradients(self, attn_gradients):
|
127 |
+
self.attn_gradients = attn_gradients
|
128 |
+
|
129 |
+
def get_attn_gradients(self):
|
130 |
+
return self.attn_gradients
|
131 |
+
|
132 |
+
def save_attention_map(self, attention_map):
|
133 |
+
self.attention_map = attention_map
|
134 |
+
|
135 |
+
def get_attention_map(self):
|
136 |
+
return self.attention_map
|
137 |
+
|
138 |
+
def transpose_for_scores(self, x):
|
139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
140 |
+
x = x.view(*new_x_shape)
|
141 |
+
return x.permute(0, 2, 1, 3)
|
142 |
+
|
143 |
+
def forward(
|
144 |
+
self,
|
145 |
+
hidden_states,
|
146 |
+
attention_mask=None,
|
147 |
+
head_mask=None,
|
148 |
+
encoder_hidden_states=None,
|
149 |
+
encoder_attention_mask=None,
|
150 |
+
past_key_value=None,
|
151 |
+
output_attentions=False,
|
152 |
+
):
|
153 |
+
mixed_query_layer = self.query(hidden_states)
|
154 |
+
|
155 |
+
# If this is instantiated as a cross-attention module, the keys
|
156 |
+
# and values come from an encoder; the attention mask needs to be
|
157 |
+
# such that the encoder's padding tokens are not attended to.
|
158 |
+
is_cross_attention = encoder_hidden_states is not None
|
159 |
+
|
160 |
+
if is_cross_attention:
|
161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
163 |
+
attention_mask = encoder_attention_mask
|
164 |
+
elif past_key_value is not None:
|
165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
169 |
+
else:
|
170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
172 |
+
|
173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
174 |
+
|
175 |
+
past_key_value = (key_layer, value_layer)
|
176 |
+
|
177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
181 |
+
seq_length = hidden_states.size()[1]
|
182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
184 |
+
distance = position_ids_l - position_ids_r
|
185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
187 |
+
|
188 |
+
if self.position_embedding_type == "relative_key":
|
189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
190 |
+
attention_scores = attention_scores + relative_position_scores
|
191 |
+
elif self.position_embedding_type == "relative_key_query":
|
192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
195 |
+
|
196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
197 |
+
if attention_mask is not None:
|
198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
199 |
+
attention_scores = attention_scores + attention_mask
|
200 |
+
|
201 |
+
# Normalize the attention scores to probabilities.
|
202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
203 |
+
|
204 |
+
if is_cross_attention and self.save_attention:
|
205 |
+
self.save_attention_map(attention_probs)
|
206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
207 |
+
|
208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
211 |
+
|
212 |
+
# Mask heads if we want to
|
213 |
+
if head_mask is not None:
|
214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
215 |
+
|
216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
217 |
+
|
218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
221 |
+
|
222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
223 |
+
|
224 |
+
outputs = outputs + (past_key_value,)
|
225 |
+
return outputs
|
226 |
+
|
227 |
+
|
228 |
+
class BertSelfOutput(nn.Module):
|
229 |
+
def __init__(self, config):
|
230 |
+
super().__init__()
|
231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
234 |
+
|
235 |
+
def forward(self, hidden_states, input_tensor):
|
236 |
+
hidden_states = self.dense(hidden_states)
|
237 |
+
hidden_states = self.dropout(hidden_states)
|
238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
239 |
+
return hidden_states
|
240 |
+
|
241 |
+
|
242 |
+
class BertAttention(nn.Module):
|
243 |
+
def __init__(self, config, is_cross_attention=False):
|
244 |
+
super().__init__()
|
245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
246 |
+
self.output = BertSelfOutput(config)
|
247 |
+
self.pruned_heads = set()
|
248 |
+
|
249 |
+
def prune_heads(self, heads):
|
250 |
+
if len(heads) == 0:
|
251 |
+
return
|
252 |
+
heads, index = find_pruneable_heads_and_indices(
|
253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
254 |
+
)
|
255 |
+
|
256 |
+
# Prune linear layers
|
257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
261 |
+
|
262 |
+
# Update hyper params and store pruned heads
|
263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states,
|
270 |
+
attention_mask=None,
|
271 |
+
head_mask=None,
|
272 |
+
encoder_hidden_states=None,
|
273 |
+
encoder_attention_mask=None,
|
274 |
+
past_key_value=None,
|
275 |
+
output_attentions=False,
|
276 |
+
):
|
277 |
+
self_outputs = self.self(
|
278 |
+
hidden_states,
|
279 |
+
attention_mask,
|
280 |
+
head_mask,
|
281 |
+
encoder_hidden_states,
|
282 |
+
encoder_attention_mask,
|
283 |
+
past_key_value,
|
284 |
+
output_attentions,
|
285 |
+
)
|
286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
288 |
+
return outputs
|
289 |
+
|
290 |
+
|
291 |
+
class BertIntermediate(nn.Module):
|
292 |
+
def __init__(self, config):
|
293 |
+
super().__init__()
|
294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
295 |
+
if isinstance(config.hidden_act, str):
|
296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
297 |
+
else:
|
298 |
+
self.intermediate_act_fn = config.hidden_act
|
299 |
+
|
300 |
+
def forward(self, hidden_states):
|
301 |
+
hidden_states = self.dense(hidden_states)
|
302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
303 |
+
return hidden_states
|
304 |
+
|
305 |
+
|
306 |
+
class BertOutput(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super().__init__()
|
309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, input_tensor):
|
314 |
+
hidden_states = self.dense(hidden_states)
|
315 |
+
hidden_states = self.dropout(hidden_states)
|
316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
317 |
+
return hidden_states
|
318 |
+
|
319 |
+
|
320 |
+
class BertLayer(nn.Module):
|
321 |
+
def __init__(self, config, layer_num):
|
322 |
+
super().__init__()
|
323 |
+
self.config = config
|
324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
325 |
+
self.seq_len_dim = 1
|
326 |
+
self.attention = BertAttention(config)
|
327 |
+
self.layer_num = layer_num
|
328 |
+
if self.config.add_cross_attention:
|
329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
330 |
+
self.intermediate = BertIntermediate(config)
|
331 |
+
self.output = BertOutput(config)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=None,
|
337 |
+
head_mask=None,
|
338 |
+
encoder_hidden_states=None,
|
339 |
+
encoder_attention_mask=None,
|
340 |
+
past_key_value=None,
|
341 |
+
output_attentions=False,
|
342 |
+
mode=None,
|
343 |
+
):
|
344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
346 |
+
self_attention_outputs = self.attention(
|
347 |
+
hidden_states,
|
348 |
+
attention_mask,
|
349 |
+
head_mask,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
past_key_value=self_attn_past_key_value,
|
352 |
+
)
|
353 |
+
attention_output = self_attention_outputs[0]
|
354 |
+
|
355 |
+
outputs = self_attention_outputs[1:-1]
|
356 |
+
present_key_value = self_attention_outputs[-1]
|
357 |
+
|
358 |
+
if mode=='multimodal':
|
359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
360 |
+
|
361 |
+
cross_attention_outputs = self.crossattention(
|
362 |
+
attention_output,
|
363 |
+
attention_mask,
|
364 |
+
head_mask,
|
365 |
+
encoder_hidden_states,
|
366 |
+
encoder_attention_mask,
|
367 |
+
output_attentions=output_attentions,
|
368 |
+
)
|
369 |
+
attention_output = cross_attention_outputs[0]
|
370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
371 |
+
layer_output = apply_chunking_to_forward(
|
372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
373 |
+
)
|
374 |
+
outputs = (layer_output,) + outputs
|
375 |
+
|
376 |
+
outputs = outputs + (present_key_value,)
|
377 |
+
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
def feed_forward_chunk(self, attention_output):
|
381 |
+
intermediate_output = self.intermediate(attention_output)
|
382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
383 |
+
return layer_output
|
384 |
+
|
385 |
+
|
386 |
+
class BertEncoder(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
super().__init__()
|
389 |
+
self.config = config
|
390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def forward(
|
394 |
+
self,
|
395 |
+
hidden_states,
|
396 |
+
attention_mask=None,
|
397 |
+
head_mask=None,
|
398 |
+
encoder_hidden_states=None,
|
399 |
+
encoder_attention_mask=None,
|
400 |
+
past_key_values=None,
|
401 |
+
use_cache=None,
|
402 |
+
output_attentions=False,
|
403 |
+
output_hidden_states=False,
|
404 |
+
return_dict=True,
|
405 |
+
mode='multimodal',
|
406 |
+
):
|
407 |
+
all_hidden_states = () if output_hidden_states else None
|
408 |
+
all_self_attentions = () if output_attentions else None
|
409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
410 |
+
|
411 |
+
next_decoder_cache = () if use_cache else None
|
412 |
+
|
413 |
+
for i in range(self.config.num_hidden_layers):
|
414 |
+
layer_module = self.layer[i]
|
415 |
+
if output_hidden_states:
|
416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
417 |
+
|
418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
420 |
+
|
421 |
+
if self.gradient_checkpointing and self.training:
|
422 |
+
|
423 |
+
if use_cache:
|
424 |
+
logger.warn(
|
425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
426 |
+
)
|
427 |
+
use_cache = False
|
428 |
+
|
429 |
+
def create_custom_forward(module):
|
430 |
+
def custom_forward(*inputs):
|
431 |
+
return module(*inputs, past_key_value, output_attentions)
|
432 |
+
|
433 |
+
return custom_forward
|
434 |
+
|
435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
436 |
+
create_custom_forward(layer_module),
|
437 |
+
hidden_states,
|
438 |
+
attention_mask,
|
439 |
+
layer_head_mask,
|
440 |
+
encoder_hidden_states,
|
441 |
+
encoder_attention_mask,
|
442 |
+
mode=mode,
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
layer_outputs = layer_module(
|
446 |
+
hidden_states,
|
447 |
+
attention_mask,
|
448 |
+
layer_head_mask,
|
449 |
+
encoder_hidden_states,
|
450 |
+
encoder_attention_mask,
|
451 |
+
past_key_value,
|
452 |
+
output_attentions,
|
453 |
+
mode=mode,
|
454 |
+
)
|
455 |
+
|
456 |
+
hidden_states = layer_outputs[0]
|
457 |
+
if use_cache:
|
458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
459 |
+
if output_attentions:
|
460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
461 |
+
|
462 |
+
if output_hidden_states:
|
463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
464 |
+
|
465 |
+
if not return_dict:
|
466 |
+
return tuple(
|
467 |
+
v
|
468 |
+
for v in [
|
469 |
+
hidden_states,
|
470 |
+
next_decoder_cache,
|
471 |
+
all_hidden_states,
|
472 |
+
all_self_attentions,
|
473 |
+
all_cross_attentions,
|
474 |
+
]
|
475 |
+
if v is not None
|
476 |
+
)
|
477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
478 |
+
last_hidden_state=hidden_states,
|
479 |
+
past_key_values=next_decoder_cache,
|
480 |
+
hidden_states=all_hidden_states,
|
481 |
+
attentions=all_self_attentions,
|
482 |
+
cross_attentions=all_cross_attentions,
|
483 |
+
)
|
484 |
+
|
485 |
+
|
486 |
+
class BertPooler(nn.Module):
|
487 |
+
def __init__(self, config):
|
488 |
+
super().__init__()
|
489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
490 |
+
self.activation = nn.Tanh()
|
491 |
+
|
492 |
+
def forward(self, hidden_states):
|
493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
494 |
+
# to the first token.
|
495 |
+
first_token_tensor = hidden_states[:, 0]
|
496 |
+
pooled_output = self.dense(first_token_tensor)
|
497 |
+
pooled_output = self.activation(pooled_output)
|
498 |
+
return pooled_output
|
499 |
+
|
500 |
+
|
501 |
+
class BertPredictionHeadTransform(nn.Module):
|
502 |
+
def __init__(self, config):
|
503 |
+
super().__init__()
|
504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
505 |
+
if isinstance(config.hidden_act, str):
|
506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
507 |
+
else:
|
508 |
+
self.transform_act_fn = config.hidden_act
|
509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
510 |
+
|
511 |
+
def forward(self, hidden_states):
|
512 |
+
hidden_states = self.dense(hidden_states)
|
513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
515 |
+
return hidden_states
|
516 |
+
|
517 |
+
|
518 |
+
class BertLMPredictionHead(nn.Module):
|
519 |
+
def __init__(self, config):
|
520 |
+
super().__init__()
|
521 |
+
self.transform = BertPredictionHeadTransform(config)
|
522 |
+
|
523 |
+
# The output weights are the same as the input embeddings, but there is
|
524 |
+
# an output-only bias for each token.
|
525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
526 |
+
|
527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
528 |
+
|
529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
530 |
+
self.decoder.bias = self.bias
|
531 |
+
|
532 |
+
def forward(self, hidden_states):
|
533 |
+
hidden_states = self.transform(hidden_states)
|
534 |
+
hidden_states = self.decoder(hidden_states)
|
535 |
+
return hidden_states
|
536 |
+
|
537 |
+
|
538 |
+
class BertOnlyMLMHead(nn.Module):
|
539 |
+
def __init__(self, config):
|
540 |
+
super().__init__()
|
541 |
+
self.predictions = BertLMPredictionHead(config)
|
542 |
+
|
543 |
+
def forward(self, sequence_output):
|
544 |
+
prediction_scores = self.predictions(sequence_output)
|
545 |
+
return prediction_scores
|
546 |
+
|
547 |
+
|
548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
549 |
+
"""
|
550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
551 |
+
models.
|
552 |
+
"""
|
553 |
+
|
554 |
+
config_class = BertConfig
|
555 |
+
base_model_prefix = "bert"
|
556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
557 |
+
|
558 |
+
def _init_weights(self, module):
|
559 |
+
""" Initialize the weights """
|
560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
564 |
+
elif isinstance(module, nn.LayerNorm):
|
565 |
+
module.bias.data.zero_()
|
566 |
+
module.weight.data.fill_(1.0)
|
567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
568 |
+
module.bias.data.zero_()
|
569 |
+
|
570 |
+
|
571 |
+
class BertModel(BertPreTrainedModel):
|
572 |
+
"""
|
573 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
574 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
575 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
576 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
577 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
578 |
+
input to the forward pass.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, config, add_pooling_layer=True):
|
582 |
+
super().__init__(config)
|
583 |
+
self.config = config
|
584 |
+
|
585 |
+
self.embeddings = BertEmbeddings(config)
|
586 |
+
|
587 |
+
self.encoder = BertEncoder(config)
|
588 |
+
|
589 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
590 |
+
|
591 |
+
self.init_weights()
|
592 |
+
|
593 |
+
|
594 |
+
def get_input_embeddings(self):
|
595 |
+
return self.embeddings.word_embeddings
|
596 |
+
|
597 |
+
def set_input_embeddings(self, value):
|
598 |
+
self.embeddings.word_embeddings = value
|
599 |
+
|
600 |
+
def _prune_heads(self, heads_to_prune):
|
601 |
+
"""
|
602 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
603 |
+
class PreTrainedModel
|
604 |
+
"""
|
605 |
+
for layer, heads in heads_to_prune.items():
|
606 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
607 |
+
|
608 |
+
|
609 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
610 |
+
"""
|
611 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
612 |
+
|
613 |
+
Arguments:
|
614 |
+
attention_mask (:obj:`torch.Tensor`):
|
615 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
616 |
+
input_shape (:obj:`Tuple[int]`):
|
617 |
+
The shape of the input to the model.
|
618 |
+
device: (:obj:`torch.device`):
|
619 |
+
The device of the input to the model.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
623 |
+
"""
|
624 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
625 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
626 |
+
if attention_mask.dim() == 3:
|
627 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
628 |
+
elif attention_mask.dim() == 2:
|
629 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
630 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
631 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
632 |
+
if is_decoder:
|
633 |
+
batch_size, seq_length = input_shape
|
634 |
+
|
635 |
+
seq_ids = torch.arange(seq_length, device=device)
|
636 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
637 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
638 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
639 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
640 |
+
|
641 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
642 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
643 |
+
causal_mask = torch.cat(
|
644 |
+
[
|
645 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
646 |
+
causal_mask,
|
647 |
+
],
|
648 |
+
axis=-1,
|
649 |
+
)
|
650 |
+
|
651 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
652 |
+
else:
|
653 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
654 |
+
else:
|
655 |
+
raise ValueError(
|
656 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
657 |
+
input_shape, attention_mask.shape
|
658 |
+
)
|
659 |
+
)
|
660 |
+
|
661 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
662 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
663 |
+
# positions we want to attend and -10000.0 for masked positions.
|
664 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
665 |
+
# effectively the same as removing these entirely.
|
666 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
667 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
668 |
+
return extended_attention_mask
|
669 |
+
|
670 |
+
def forward(
|
671 |
+
self,
|
672 |
+
input_ids=None,
|
673 |
+
attention_mask=None,
|
674 |
+
position_ids=None,
|
675 |
+
head_mask=None,
|
676 |
+
inputs_embeds=None,
|
677 |
+
encoder_embeds=None,
|
678 |
+
encoder_hidden_states=None,
|
679 |
+
encoder_attention_mask=None,
|
680 |
+
past_key_values=None,
|
681 |
+
use_cache=None,
|
682 |
+
output_attentions=None,
|
683 |
+
output_hidden_states=None,
|
684 |
+
return_dict=None,
|
685 |
+
is_decoder=False,
|
686 |
+
mode='multimodal',
|
687 |
+
):
|
688 |
+
r"""
|
689 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
690 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
691 |
+
the model is configured as a decoder.
|
692 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
693 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
694 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
695 |
+
- 1 for tokens that are **not masked**,
|
696 |
+
- 0 for tokens that are **masked**.
|
697 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
698 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
699 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
700 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
701 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
702 |
+
use_cache (:obj:`bool`, `optional`):
|
703 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
704 |
+
decoding (see :obj:`past_key_values`).
|
705 |
+
"""
|
706 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
707 |
+
output_hidden_states = (
|
708 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
709 |
+
)
|
710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
711 |
+
|
712 |
+
if is_decoder:
|
713 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
714 |
+
else:
|
715 |
+
use_cache = False
|
716 |
+
|
717 |
+
if input_ids is not None and inputs_embeds is not None:
|
718 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
719 |
+
elif input_ids is not None:
|
720 |
+
input_shape = input_ids.size()
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = input_ids.device
|
723 |
+
elif inputs_embeds is not None:
|
724 |
+
input_shape = inputs_embeds.size()[:-1]
|
725 |
+
batch_size, seq_length = input_shape
|
726 |
+
device = inputs_embeds.device
|
727 |
+
elif encoder_embeds is not None:
|
728 |
+
input_shape = encoder_embeds.size()[:-1]
|
729 |
+
batch_size, seq_length = input_shape
|
730 |
+
device = encoder_embeds.device
|
731 |
+
else:
|
732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
733 |
+
|
734 |
+
# past_key_values_length
|
735 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
736 |
+
|
737 |
+
if attention_mask is None:
|
738 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
739 |
+
|
740 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
741 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
742 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
743 |
+
device, is_decoder)
|
744 |
+
|
745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
746 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
747 |
+
if encoder_hidden_states is not None:
|
748 |
+
if type(encoder_hidden_states) == list:
|
749 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
750 |
+
else:
|
751 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
752 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
753 |
+
|
754 |
+
if type(encoder_attention_mask) == list:
|
755 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
756 |
+
elif encoder_attention_mask is None:
|
757 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
758 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
759 |
+
else:
|
760 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
761 |
+
else:
|
762 |
+
encoder_extended_attention_mask = None
|
763 |
+
|
764 |
+
# Prepare head mask if needed
|
765 |
+
# 1.0 in head_mask indicate we keep the head
|
766 |
+
# attention_probs has shape bsz x n_heads x N x N
|
767 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
768 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
769 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
770 |
+
|
771 |
+
if encoder_embeds is None:
|
772 |
+
embedding_output = self.embeddings(
|
773 |
+
input_ids=input_ids,
|
774 |
+
position_ids=position_ids,
|
775 |
+
inputs_embeds=inputs_embeds,
|
776 |
+
past_key_values_length=past_key_values_length,
|
777 |
+
)
|
778 |
+
else:
|
779 |
+
embedding_output = encoder_embeds
|
780 |
+
|
781 |
+
encoder_outputs = self.encoder(
|
782 |
+
embedding_output,
|
783 |
+
attention_mask=extended_attention_mask,
|
784 |
+
head_mask=head_mask,
|
785 |
+
encoder_hidden_states=encoder_hidden_states,
|
786 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
787 |
+
past_key_values=past_key_values,
|
788 |
+
use_cache=use_cache,
|
789 |
+
output_attentions=output_attentions,
|
790 |
+
output_hidden_states=output_hidden_states,
|
791 |
+
return_dict=return_dict,
|
792 |
+
mode=mode,
|
793 |
+
)
|
794 |
+
sequence_output = encoder_outputs[0]
|
795 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
796 |
+
|
797 |
+
if not return_dict:
|
798 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
799 |
+
|
800 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
801 |
+
last_hidden_state=sequence_output,
|
802 |
+
pooler_output=pooled_output,
|
803 |
+
past_key_values=encoder_outputs.past_key_values,
|
804 |
+
hidden_states=encoder_outputs.hidden_states,
|
805 |
+
attentions=encoder_outputs.attentions,
|
806 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
807 |
+
)
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
812 |
+
|
813 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
814 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
815 |
+
|
816 |
+
def __init__(self, config):
|
817 |
+
super().__init__(config)
|
818 |
+
|
819 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
820 |
+
self.cls = BertOnlyMLMHead(config)
|
821 |
+
|
822 |
+
self.init_weights()
|
823 |
+
|
824 |
+
def get_output_embeddings(self):
|
825 |
+
return self.cls.predictions.decoder
|
826 |
+
|
827 |
+
def set_output_embeddings(self, new_embeddings):
|
828 |
+
self.cls.predictions.decoder = new_embeddings
|
829 |
+
|
830 |
+
def forward(
|
831 |
+
self,
|
832 |
+
input_ids=None,
|
833 |
+
attention_mask=None,
|
834 |
+
position_ids=None,
|
835 |
+
head_mask=None,
|
836 |
+
inputs_embeds=None,
|
837 |
+
encoder_hidden_states=None,
|
838 |
+
encoder_attention_mask=None,
|
839 |
+
labels=None,
|
840 |
+
past_key_values=None,
|
841 |
+
use_cache=None,
|
842 |
+
output_attentions=None,
|
843 |
+
output_hidden_states=None,
|
844 |
+
return_dict=None,
|
845 |
+
return_logits=False,
|
846 |
+
is_decoder=True,
|
847 |
+
reduction='mean',
|
848 |
+
mode='multimodal',
|
849 |
+
):
|
850 |
+
r"""
|
851 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
852 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
853 |
+
the model is configured as a decoder.
|
854 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
855 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
856 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
857 |
+
- 1 for tokens that are **not masked**,
|
858 |
+
- 0 for tokens that are **masked**.
|
859 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
860 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
861 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
862 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
863 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
864 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
865 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
866 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
867 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
868 |
+
use_cache (:obj:`bool`, `optional`):
|
869 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
870 |
+
decoding (see :obj:`past_key_values`).
|
871 |
+
Returns:
|
872 |
+
Example::
|
873 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
874 |
+
>>> import torch
|
875 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
876 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
877 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
878 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
879 |
+
>>> outputs = model(**inputs)
|
880 |
+
>>> prediction_logits = outputs.logits
|
881 |
+
"""
|
882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
883 |
+
if labels is not None:
|
884 |
+
use_cache = False
|
885 |
+
|
886 |
+
outputs = self.bert(
|
887 |
+
input_ids,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
position_ids=position_ids,
|
890 |
+
head_mask=head_mask,
|
891 |
+
inputs_embeds=inputs_embeds,
|
892 |
+
encoder_hidden_states=encoder_hidden_states,
|
893 |
+
encoder_attention_mask=encoder_attention_mask,
|
894 |
+
past_key_values=past_key_values,
|
895 |
+
use_cache=use_cache,
|
896 |
+
output_attentions=output_attentions,
|
897 |
+
output_hidden_states=output_hidden_states,
|
898 |
+
return_dict=return_dict,
|
899 |
+
is_decoder=is_decoder,
|
900 |
+
mode=mode,
|
901 |
+
)
|
902 |
+
|
903 |
+
sequence_output = outputs[0]
|
904 |
+
prediction_scores = self.cls(sequence_output)
|
905 |
+
|
906 |
+
if return_logits:
|
907 |
+
return prediction_scores[:, :-1, :].contiguous()
|
908 |
+
|
909 |
+
lm_loss = None
|
910 |
+
if labels is not None:
|
911 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
912 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
913 |
+
labels = labels[:, 1:].contiguous()
|
914 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
915 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
916 |
+
if reduction=='none':
|
917 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
output = (prediction_scores,) + outputs[2:]
|
921 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
922 |
+
|
923 |
+
return CausalLMOutputWithCrossAttentions(
|
924 |
+
loss=lm_loss,
|
925 |
+
logits=prediction_scores,
|
926 |
+
past_key_values=outputs.past_key_values,
|
927 |
+
hidden_states=outputs.hidden_states,
|
928 |
+
attentions=outputs.attentions,
|
929 |
+
cross_attentions=outputs.cross_attentions,
|
930 |
+
)
|
931 |
+
|
932 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
933 |
+
input_shape = input_ids.shape
|
934 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
935 |
+
if attention_mask is None:
|
936 |
+
attention_mask = input_ids.new_ones(input_shape)
|
937 |
+
|
938 |
+
# cut decoder_input_ids if past is used
|
939 |
+
if past is not None:
|
940 |
+
input_ids = input_ids[:, -1:]
|
941 |
+
|
942 |
+
return {
|
943 |
+
"input_ids": input_ids,
|
944 |
+
"attention_mask": attention_mask,
|
945 |
+
"past_key_values": past,
|
946 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
947 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
948 |
+
"is_decoder": True,
|
949 |
+
}
|
950 |
+
|
951 |
+
def _reorder_cache(self, past, beam_idx):
|
952 |
+
reordered_past = ()
|
953 |
+
for layer_past in past:
|
954 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
955 |
+
return reordered_past
|
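The causal-LM loss in BertLMHeadModel above shifts predictions and labels by one position so that position t predicts token t+1, with -100 labels (prompt and padding) ignored. A standalone sketch with toy tensors, mirroring that shifting (not repo code):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 5, vocab_size)  # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))
labels[:, :2] = -100                                # e.g. prompt/pad positions are ignored

# predict token t+1 from position t: drop the last score, drop the first label
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(reduction='mean', label_smoothing=0.1)  # -100 is ignored by default
loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
print(loss)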
finetune/blip/med_config.json
ADDED
@@ -0,0 +1,22 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}

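A short sketch of how blip.py consumes this file (see BLIP_Base/BLIP_Decoder.__init__ above); the path is relative to this repo and the printed fields are the ones defined in the JSON:

from transformers import BertConfig

med_config = BertConfig.from_json_file("finetune/blip/med_config.json")
print(med_config.vocab_size, med_config.encoder_width, med_config.add_cross_attention)

# blip.py then overwrites encoder_width with the ViT embedding dim before building the text model:
med_config.encoder_width = 768  # 768 for vit='base', 1024 for vit='large'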
finetune/blip/vit.py
ADDED
@@ -0,0 +1,305 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on timm code base
 * https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if use_grad_checkpointing:
            self.attn = checkpoint_wrapper(self.attn)
|
105 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
106 |
+
|
107 |
+
def forward(self, x, register_hook=False):
|
108 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
109 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class VisionTransformer(nn.Module):
|
114 |
+
""" Vision Transformer
|
115 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
116 |
+
https://arxiv.org/abs/2010.11929
|
117 |
+
"""
|
118 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
119 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
120 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
121 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
img_size (int, tuple): input image size
|
125 |
+
patch_size (int, tuple): patch size
|
126 |
+
in_chans (int): number of input channels
|
127 |
+
num_classes (int): number of classes for classification head
|
128 |
+
embed_dim (int): embedding dimension
|
129 |
+
depth (int): depth of transformer
|
130 |
+
num_heads (int): number of attention heads
|
131 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
132 |
+
qkv_bias (bool): enable bias for qkv if True
|
133 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
134 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
135 |
+
drop_rate (float): dropout rate
|
136 |
+
attn_drop_rate (float): attention dropout rate
|
137 |
+
drop_path_rate (float): stochastic depth rate
|
138 |
+
norm_layer: (nn.Module): normalization layer
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
142 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
143 |
+
|
144 |
+
self.patch_embed = PatchEmbed(
|
145 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
146 |
+
|
147 |
+
num_patches = self.patch_embed.num_patches
|
148 |
+
|
149 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
150 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
151 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
152 |
+
|
153 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
154 |
+
self.blocks = nn.ModuleList([
|
155 |
+
Block(
|
156 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
157 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
158 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
159 |
+
)
|
160 |
+
for i in range(depth)])
|
161 |
+
self.norm = norm_layer(embed_dim)
|
162 |
+
|
163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
164 |
+
trunc_normal_(self.cls_token, std=.02)
|
165 |
+
self.apply(self._init_weights)
|
166 |
+
|
167 |
+
def _init_weights(self, m):
|
168 |
+
if isinstance(m, nn.Linear):
|
169 |
+
trunc_normal_(m.weight, std=.02)
|
170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
elif isinstance(m, nn.LayerNorm):
|
173 |
+
nn.init.constant_(m.bias, 0)
|
174 |
+
nn.init.constant_(m.weight, 1.0)
|
175 |
+
|
176 |
+
@torch.jit.ignore
|
177 |
+
def no_weight_decay(self):
|
178 |
+
return {'pos_embed', 'cls_token'}
|
179 |
+
|
180 |
+
def forward(self, x, register_blk=-1):
|
181 |
+
B = x.shape[0]
|
182 |
+
x = self.patch_embed(x)
|
183 |
+
|
184 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
185 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
186 |
+
|
187 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
188 |
+
x = self.pos_drop(x)
|
189 |
+
|
190 |
+
for i,blk in enumerate(self.blocks):
|
191 |
+
x = blk(x, register_blk==i)
|
192 |
+
x = self.norm(x)
|
193 |
+
|
194 |
+
return x
|
195 |
+
|
196 |
+
@torch.jit.ignore()
|
197 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
198 |
+
_load_weights(self, checkpoint_path, prefix)
|
199 |
+
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
203 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
204 |
+
"""
|
205 |
+
import numpy as np
|
206 |
+
|
207 |
+
def _n2p(w, t=True):
|
208 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
209 |
+
w = w.flatten()
|
210 |
+
if t:
|
211 |
+
if w.ndim == 4:
|
212 |
+
w = w.transpose([3, 2, 0, 1])
|
213 |
+
elif w.ndim == 3:
|
214 |
+
w = w.transpose([2, 0, 1])
|
215 |
+
elif w.ndim == 2:
|
216 |
+
w = w.transpose([1, 0])
|
217 |
+
return torch.from_numpy(w)
|
218 |
+
|
219 |
+
w = np.load(checkpoint_path)
|
220 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
221 |
+
prefix = 'opt/target/'
|
222 |
+
|
223 |
+
if hasattr(model.patch_embed, 'backbone'):
|
224 |
+
# hybrid
|
225 |
+
backbone = model.patch_embed.backbone
|
226 |
+
stem_only = not hasattr(backbone, 'stem')
|
227 |
+
stem = backbone if stem_only else backbone.stem
|
228 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
229 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
230 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
231 |
+
if not stem_only:
|
232 |
+
for i, stage in enumerate(backbone.stages):
|
233 |
+
for j, block in enumerate(stage.blocks):
|
234 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
235 |
+
for r in range(3):
|
236 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
237 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
238 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
239 |
+
if block.downsample is not None:
|
240 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
241 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
242 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
243 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
244 |
+
else:
|
245 |
+
embed_conv_w = adapt_input_conv(
|
246 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
247 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
248 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
249 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
250 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
251 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
252 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
253 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
254 |
+
model.pos_embed.copy_(pos_embed_w)
|
255 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
256 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
257 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
258 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
259 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
260 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
261 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
262 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
263 |
+
for i, block in enumerate(model.blocks.children()):
|
264 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
265 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
266 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
267 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
268 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
269 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
270 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
271 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
272 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
273 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
274 |
+
for r in range(2):
|
275 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
276 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
277 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
278 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
279 |
+
|
280 |
+
|
281 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
282 |
+
# interpolate position embedding
|
283 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
284 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
285 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
286 |
+
# height (== width) for the checkpoint position embedding
|
287 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
288 |
+
# height (== width) for the new position embedding
|
289 |
+
new_size = int(num_patches ** 0.5)
|
290 |
+
|
291 |
+
if orig_size!=new_size:
|
292 |
+
# class_token and dist_token are kept unchanged
|
293 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
294 |
+
# only the position tokens are interpolated
|
295 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
296 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
297 |
+
pos_tokens = torch.nn.functional.interpolate(
|
298 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
299 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
300 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
301 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
302 |
+
|
303 |
+
return new_pos_embed
|
304 |
+
else:
|
305 |
+
return pos_embed_checkpoint
|
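As a quick sanity check, the class above can be instantiated with the ViT-B/16 settings BLIP uses and run on a dummy batch; the output keeps one CLS token plus one token per 16x16 patch. A rough sketch, assuming timm and fairscale are installed and the working directory is finetune so the module imports as blip.vit (the import path is an assumption):

import torch
from blip.vit import VisionTransformer

# ViT-B/16: a 224x224 input gives 14*14 = 196 patches plus the CLS token, 768-dim features
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                        depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True)
vit.eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 197, 768])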
finetune/clean_captions_and_tags.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
# このスクリプトのライセンスは、Apache License 2.0とします
|
2 |
+
# (c) 2022 Kohya S. @kohya_ss
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import re
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
13 |
+
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
14 |
+
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
|
15 |
+
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
|
16 |
+
|
17 |
+
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
|
18 |
+
PATTERNS_REMOVE_IN_MULTI = [
|
19 |
+
PATTERN_HAIR_LENGTH,
|
20 |
+
PATTERN_HAIR_CUT,
|
21 |
+
re.compile(r', [\w\-]+ eyes, '),
|
22 |
+
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
|
23 |
+
# 複数の髪型定義がある場合は削除する
|
24 |
+
re.compile(
|
25 |
+
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
def clean_tags(image_key, tags):
|
30 |
+
# replace '_' to ' '
|
31 |
+
tags = tags.replace('^_^', '^@@@^')
|
32 |
+
tags = tags.replace('_', ' ')
|
33 |
+
tags = tags.replace('^@@@^', '^_^')
|
34 |
+
|
35 |
+
# remove rating: deepdanbooruのみ
|
36 |
+
tokens = tags.split(", rating")
|
37 |
+
if len(tokens) == 1:
|
38 |
+
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
39 |
+
# print("no rating:")
|
40 |
+
# print(f"{image_key} {tags}")
|
41 |
+
pass
|
42 |
+
else:
|
43 |
+
if len(tokens) > 2:
|
44 |
+
print("multiple ratings:")
|
45 |
+
print(f"{image_key} {tags}")
|
46 |
+
tags = tokens[0]
|
47 |
+
|
48 |
+
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
49 |
+
|
50 |
+
# 複数の人物がいる場合は髪色等のタグを削除する
|
51 |
+
if 'girls' in tags or 'boys' in tags:
|
52 |
+
for pat in PATTERNS_REMOVE_IN_MULTI:
|
53 |
+
found = pat.findall(tags)
|
54 |
+
if len(found) > 1: # 二つ以上、タグがある
|
55 |
+
tags = pat.sub("", tags)
|
56 |
+
|
57 |
+
# 髪の特殊対応
|
58 |
+
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
|
59 |
+
if srch_hair_len:
|
60 |
+
org = srch_hair_len.group()
|
61 |
+
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
|
62 |
+
|
63 |
+
found = PATTERN_HAIR.findall(tags)
|
64 |
+
if len(found) > 1:
|
65 |
+
tags = PATTERN_HAIR.sub("", tags)
|
66 |
+
|
67 |
+
if srch_hair_len:
|
68 |
+
tags = tags.replace(", @@@, ", org) # 戻す
|
69 |
+
|
70 |
+
# white shirtとshirtみたいな重複タグの削除
|
71 |
+
found = PATTERN_WORD.findall(tags)
|
72 |
+
for word in found:
|
73 |
+
if re.search(f", ((\w+) )+{word}, ", tags):
|
74 |
+
tags = tags.replace(f", {word}, ", "")
|
75 |
+
|
76 |
+
tags = tags.replace(", , ", ", ")
|
77 |
+
assert tags.startswith(", ") and tags.endswith(", ")
|
78 |
+
tags = tags[2:-2]
|
79 |
+
return tags
|
80 |
+
|
81 |
+
|
82 |
+
# 上から順に検索、置換される
|
83 |
+
# ('置換元文字列', '置換後文字列')
|
84 |
+
CAPTION_REPLACEMENTS = [
|
85 |
+
('anime anime', 'anime'),
|
86 |
+
('young ', ''),
|
87 |
+
('anime girl', 'girl'),
|
88 |
+
('cartoon female', 'girl'),
|
89 |
+
('cartoon lady', 'girl'),
|
90 |
+
('cartoon character', 'girl'), # a or ~s
|
91 |
+
('cartoon woman', 'girl'),
|
92 |
+
('cartoon women', 'girls'),
|
93 |
+
('cartoon girl', 'girl'),
|
94 |
+
('anime female', 'girl'),
|
95 |
+
('anime lady', 'girl'),
|
96 |
+
('anime character', 'girl'), # a or ~s
|
97 |
+
('anime woman', 'girl'),
|
98 |
+
('anime women', 'girls'),
|
99 |
+
('lady', 'girl'),
|
100 |
+
('female', 'girl'),
|
101 |
+
('woman', 'girl'),
|
102 |
+
('women', 'girls'),
|
103 |
+
('people', 'girls'),
|
104 |
+
('person', 'girl'),
|
105 |
+
('a cartoon figure', 'a figure'),
|
106 |
+
('a cartoon image', 'an image'),
|
107 |
+
('a cartoon picture', 'a picture'),
|
108 |
+
('an anime cartoon image', 'an image'),
|
109 |
+
('a cartoon anime drawing', 'a drawing'),
|
110 |
+
('a cartoon drawing', 'a drawing'),
|
111 |
+
('girl girl', 'girl'),
|
112 |
+
]
|
113 |
+
|
114 |
+
|
115 |
+
def clean_caption(caption):
|
116 |
+
for rf, rt in CAPTION_REPLACEMENTS:
|
117 |
+
replaced = True
|
118 |
+
while replaced:
|
119 |
+
bef = caption
|
120 |
+
caption = caption.replace(rf, rt)
|
121 |
+
replaced = bef != caption
|
122 |
+
return caption
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
if os.path.exists(args.in_json):
|
127 |
+
print(f"loading existing metadata: {args.in_json}")
|
128 |
+
with open(args.in_json, "rt", encoding='utf-8') as f:
|
129 |
+
metadata = json.load(f)
|
130 |
+
else:
|
131 |
+
print("no metadata / メタデータファイルがありません")
|
132 |
+
return
|
133 |
+
|
134 |
+
print("cleaning captions and tags.")
|
135 |
+
image_keys = list(metadata.keys())
|
136 |
+
for image_key in tqdm(image_keys):
|
137 |
+
tags = metadata[image_key].get('tags')
|
138 |
+
if tags is None:
|
139 |
+
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
140 |
+
else:
|
141 |
+
org = tags
|
142 |
+
tags = clean_tags(image_key, tags)
|
143 |
+
metadata[image_key]['tags'] = tags
|
144 |
+
if args.debug and org != tags:
|
145 |
+
print("FROM: " + org)
|
146 |
+
print("TO: " + tags)
|
147 |
+
|
148 |
+
caption = metadata[image_key].get('caption')
|
149 |
+
if caption is None:
|
150 |
+
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
151 |
+
else:
|
152 |
+
org = caption
|
153 |
+
caption = clean_caption(caption)
|
154 |
+
metadata[image_key]['caption'] = caption
|
155 |
+
if args.debug and org != caption:
|
156 |
+
print("FROM: " + org)
|
157 |
+
print("TO: " + caption)
|
158 |
+
|
159 |
+
# metadataを書き出して終わり
|
160 |
+
print(f"writing metadata: {args.out_json}")
|
161 |
+
with open(args.out_json, "wt", encoding='utf-8') as f:
|
162 |
+
json.dump(metadata, f, indent=2)
|
163 |
+
print("done!")
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == '__main__':
|
167 |
+
parser = argparse.ArgumentParser()
|
168 |
+
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
169 |
+
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
170 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
171 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
172 |
+
|
173 |
+
args, unknown = parser.parse_known_args()
|
174 |
+
if len(unknown) == 1:
|
175 |
+
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
176 |
+
print("All captions and tags in the metadata are processed.")
|
177 |
+
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
178 |
+
print("メタデータ内のすべてのキャプションとタグが処理されます。")
|
179 |
+
args.in_json = args.out_json
|
180 |
+
args.out_json = unknown[0]
|
181 |
+
elif len(unknown) > 0:
|
182 |
+
raise ValueError(f"error: unrecognized arguments: {unknown}")
|
183 |
+
|
184 |
+
main(args)
|
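To illustrate the caption cleanup, clean_caption applies the replacement table top to bottom, repeating each rule until it no longer matches. A small example, assuming the working directory is finetune so the script imports as clean_captions_and_tags (the import path is an assumption):

from clean_captions_and_tags import clean_caption

# 'young ' is dropped, 'anime woman' becomes 'girl', then 'a cartoon picture' becomes 'a picture'
print(clean_caption("a cartoon picture of a young anime woman"))
# -> 'a picture of a girl'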
finetune/hypernetwork_nai.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
# NAI compatible
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class HypernetworkModule(torch.nn.Module):
|
7 |
+
def __init__(self, dim, multiplier=1.0):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
linear1 = torch.nn.Linear(dim, dim * 2)
|
11 |
+
linear2 = torch.nn.Linear(dim * 2, dim)
|
12 |
+
linear1.weight.data.normal_(mean=0.0, std=0.01)
|
13 |
+
linear1.bias.data.zero_()
|
14 |
+
linear2.weight.data.normal_(mean=0.0, std=0.01)
|
15 |
+
linear2.bias.data.zero_()
|
16 |
+
linears = [linear1, linear2]
|
17 |
+
|
18 |
+
self.linear = torch.nn.Sequential(*linears)
|
19 |
+
self.multiplier = multiplier
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
return x + self.linear(x) * self.multiplier
|
23 |
+
|
24 |
+
|
25 |
+
class Hypernetwork(torch.nn.Module):
|
26 |
+
enable_sizes = [320, 640, 768, 1280]
|
27 |
+
# return self.modules[Hypernetwork.enable_sizes.index(size)]
|
28 |
+
|
29 |
+
def __init__(self, multiplier=1.0) -> None:
|
30 |
+
super().__init__()
|
31 |
+
self.modules = []
|
32 |
+
for size in Hypernetwork.enable_sizes:
|
33 |
+
self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
|
34 |
+
self.register_module(f"{size}_0", self.modules[-1][0])
|
35 |
+
self.register_module(f"{size}_1", self.modules[-1][1])
|
36 |
+
|
37 |
+
def apply_to_stable_diffusion(self, text_encoder, vae, unet):
|
38 |
+
blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
|
39 |
+
for block in blocks:
|
40 |
+
for subblk in block:
|
41 |
+
if 'SpatialTransformer' in str(type(subblk)):
|
42 |
+
for tf_block in subblk.transformer_blocks:
|
43 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
44 |
+
size = attn.context_dim
|
45 |
+
if size in Hypernetwork.enable_sizes:
|
46 |
+
attn.hypernetwork = self
|
47 |
+
else:
|
48 |
+
attn.hypernetwork = None
|
49 |
+
|
50 |
+
def apply_to_diffusers(self, text_encoder, vae, unet):
|
51 |
+
blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
|
52 |
+
for block in blocks:
|
53 |
+
if hasattr(block, 'attentions'):
|
54 |
+
for subblk in block.attentions:
|
55 |
+
if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
|
56 |
+
for tf_block in subblk.transformer_blocks:
|
57 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
58 |
+
size = attn.to_k.in_features
|
59 |
+
if size in Hypernetwork.enable_sizes:
|
60 |
+
attn.hypernetwork = self
|
61 |
+
else:
|
62 |
+
attn.hypernetwork = None
|
63 |
+
return True # TODO error checking
|
64 |
+
|
65 |
+
def forward(self, x, context):
|
66 |
+
size = context.shape[-1]
|
67 |
+
assert size in Hypernetwork.enable_sizes
|
68 |
+
module = self.modules[Hypernetwork.enable_sizes.index(size)]
|
69 |
+
return module[0].forward(context), module[1].forward(context)
|
70 |
+
|
71 |
+
def load_from_state_dict(self, state_dict):
|
72 |
+
# old ver to new ver
|
73 |
+
changes = {
|
74 |
+
'linear1.bias': 'linear.0.bias',
|
75 |
+
'linear1.weight': 'linear.0.weight',
|
76 |
+
'linear2.bias': 'linear.1.bias',
|
77 |
+
'linear2.weight': 'linear.1.weight',
|
78 |
+
}
|
79 |
+
for key_from, key_to in changes.items():
|
80 |
+
if key_from in state_dict:
|
81 |
+
state_dict[key_to] = state_dict[key_from]
|
82 |
+
del state_dict[key_from]
|
83 |
+
|
84 |
+
for size, sd in state_dict.items():
|
85 |
+
if type(size) == int:
|
86 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
|
87 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
|
88 |
+
return True
|
89 |
+
|
90 |
+
def get_state_dict(self):
|
91 |
+
state_dict = {}
|
92 |
+
for i, size in enumerate(Hypernetwork.enable_sizes):
|
93 |
+
sd0 = self.modules[i][0].state_dict()
|
94 |
+
sd1 = self.modules[i][1].state_dict()
|
95 |
+
state_dict[size] = [sd0, sd1]
|
96 |
+
return state_dict
|
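The hypernetwork keeps one pair of small residual MLPs per supported attention context width; at inference the attention layer passes its context through forward and receives adjusted inputs for the key and value projections, with shapes unchanged. A minimal sketch on random data (shapes and the import path are illustrative assumptions):

import torch
from hypernetwork_nai import Hypernetwork  # assumes the working directory is finetune

hn = Hypernetwork(multiplier=1.0)
context = torch.randn(2, 77, 768)  # 768 is one of the supported context widths

k_ctx, v_ctx = hn.forward(context)
print(k_ctx.shape, v_ctx.shape)    # both torch.Size([2, 77, 768])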
finetune/make_captions.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import random
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode
|
13 |
+
from blip.blip import blip_decoder
|
14 |
+
import library.train_util as train_util
|
15 |
+
|
16 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
|
18 |
+
|
19 |
+
IMAGE_SIZE = 384
|
20 |
+
|
21 |
+
# 正方形でいいのか? という気がするがソースがそうなので
|
22 |
+
IMAGE_TRANSFORM = transforms.Compose([
|
23 |
+
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
24 |
+
transforms.ToTensor(),
|
25 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
26 |
+
])
|
27 |
+
|
28 |
+
# 共通化したいが微妙に処理が異なる……
|
29 |
+
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
30 |
+
def __init__(self, image_paths):
|
31 |
+
self.images = image_paths
|
32 |
+
|
33 |
+
def __len__(self):
|
34 |
+
return len(self.images)
|
35 |
+
|
36 |
+
def __getitem__(self, idx):
|
37 |
+
img_path = self.images[idx]
|
38 |
+
|
39 |
+
try:
|
40 |
+
image = Image.open(img_path).convert("RGB")
|
41 |
+
# convert to tensor temporarily so dataloader will accept it
|
42 |
+
tensor = IMAGE_TRANSFORM(image)
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
45 |
+
return None
|
46 |
+
|
47 |
+
return (tensor, img_path)
|
48 |
+
|
49 |
+
|
50 |
+
def collate_fn_remove_corrupted(batch):
|
51 |
+
"""Collate function that allows to remove corrupted examples in the
|
52 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
53 |
+
The 'None's in the batch are removed.
|
54 |
+
"""
|
55 |
+
# Filter out all the Nones (corrupted examples)
|
56 |
+
batch = list(filter(lambda x: x is not None, batch))
|
57 |
+
return batch
|
58 |
+
|
59 |
+
|
60 |
+
def main(args):
|
61 |
+
# fix the seed for reproducibility
|
62 |
+
seed = args.seed # + utils.get_rank()
|
63 |
+
torch.manual_seed(seed)
|
64 |
+
np.random.seed(seed)
|
65 |
+
random.seed(seed)
|
66 |
+
|
67 |
+
if not os.path.exists("blip"):
|
68 |
+
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
69 |
+
|
70 |
+
cwd = os.getcwd()
|
71 |
+
print('Current Working Directory is: ', cwd)
|
72 |
+
os.chdir('finetune')
|
73 |
+
|
74 |
+
print(f"load images from {args.train_data_dir}")
|
75 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
76 |
+
print(f"found {len(image_paths)} images.")
|
77 |
+
|
78 |
+
print(f"loading BLIP caption: {args.caption_weights}")
|
79 |
+
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
|
80 |
+
model.eval()
|
81 |
+
model = model.to(DEVICE)
|
82 |
+
print("BLIP loaded")
|
83 |
+
|
84 |
+
# captioningする
|
85 |
+
def run_batch(path_imgs):
|
86 |
+
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
87 |
+
|
88 |
+
with torch.no_grad():
|
89 |
+
if args.beam_search:
|
90 |
+
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
91 |
+
max_length=args.max_length, min_length=args.min_length)
|
92 |
+
else:
|
93 |
+
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
94 |
+
|
95 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
96 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
97 |
+
f.write(caption + "\n")
|
98 |
+
if args.debug:
|
99 |
+
print(image_path, caption)
|
100 |
+
|
101 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
102 |
+
if args.max_data_loader_n_workers is not None:
|
103 |
+
dataset = ImageLoadingTransformDataset(image_paths)
|
104 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
105 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
106 |
+
else:
|
107 |
+
data = [[(None, ip)] for ip in image_paths]
|
108 |
+
|
109 |
+
b_imgs = []
|
110 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
111 |
+
for data in data_entry:
|
112 |
+
if data is None:
|
113 |
+
continue
|
114 |
+
|
115 |
+
img_tensor, image_path = data
|
116 |
+
if img_tensor is None:
|
117 |
+
try:
|
118 |
+
raw_image = Image.open(image_path)
|
119 |
+
if raw_image.mode != 'RGB':
|
120 |
+
raw_image = raw_image.convert("RGB")
|
121 |
+
img_tensor = IMAGE_TRANSFORM(raw_image)
|
122 |
+
except Exception as e:
|
123 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
124 |
+
continue
|
125 |
+
|
126 |
+
b_imgs.append((image_path, img_tensor))
|
127 |
+
if len(b_imgs) >= args.batch_size:
|
128 |
+
run_batch(b_imgs)
|
129 |
+
b_imgs.clear()
|
130 |
+
if len(b_imgs) > 0:
|
131 |
+
run_batch(b_imgs)
|
132 |
+
|
133 |
+
print("done!")
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
139 |
+
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
140 |
+
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
141 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
142 |
+
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
143 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
144 |
+
parser.add_argument("--beam_search", action="store_true",
|
145 |
+
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
146 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
147 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
148 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
149 |
+
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
150 |
+
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
151 |
+
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
152 |
+
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
153 |
+
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
154 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
# スペルミスしていたオプションを復元する
|
159 |
+
if args.caption_extention is not None:
|
160 |
+
args.caption_extension = args.caption_extention
|
161 |
+
|
162 |
+
main(args)
|
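The DataLoader path above relies on the dataset returning None for unreadable images and on the collate function dropping those entries, so a single corrupted file does not abort a long captioning run. A self-contained toy sketch of that pattern (the dataset below is invented for illustration):

import torch

class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        value = self.items[idx]
        # None stands in for an image that failed to load
        return None if value is None else (torch.tensor(float(value)), f"img_{idx}.png")

def collate_fn_remove_corrupted(batch):
    return [x for x in batch if x is not None]

loader = torch.utils.data.DataLoader(ToyDataset([1, None, 3, 4]), batch_size=2,
                                     collate_fn=collate_fn_remove_corrupted)
for batch in loader:
    print(len(batch))  # 1, then 2 -- the None entry is dropped, the rest survive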
finetune/make_captions_by_git.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
import torch
|
8 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
9 |
+
from transformers.generation.utils import GenerationMixin
|
10 |
+
|
11 |
+
import library.train_util as train_util
|
12 |
+
|
13 |
+
|
14 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
+
|
16 |
+
PATTERN_REPLACE = [
|
17 |
+
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
18 |
+
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
|
19 |
+
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
|
20 |
+
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
|
21 |
+
re.compile(r'with the words "'),
|
22 |
+
re.compile(r'word \w+ on it'),
|
23 |
+
re.compile(r'that says the word \w+ on it'),
|
24 |
+
re.compile('that says\'the word "( on it)?'),
|
25 |
+
]
|
26 |
+
|
27 |
+
# 誤検知しまくりの with the word xxxx を消す
|
28 |
+
|
29 |
+
|
30 |
+
def remove_words(captions, debug):
|
31 |
+
removed_caps = []
|
32 |
+
for caption in captions:
|
33 |
+
cap = caption
|
34 |
+
for pat in PATTERN_REPLACE:
|
35 |
+
cap = pat.sub("", cap)
|
36 |
+
if debug and cap != caption:
|
37 |
+
print(caption)
|
38 |
+
print(cap)
|
39 |
+
removed_caps.append(cap)
|
40 |
+
return removed_caps
|
41 |
+
|
42 |
+
|
43 |
+
def collate_fn_remove_corrupted(batch):
|
44 |
+
"""Collate function that allows to remove corrupted examples in the
|
45 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
46 |
+
The 'None's in the batch are removed.
|
47 |
+
"""
|
48 |
+
# Filter out all the Nones (corrupted examples)
|
49 |
+
batch = list(filter(lambda x: x is not None, batch))
|
50 |
+
return batch
|
51 |
+
|
52 |
+
|
53 |
+
def main(args):
|
54 |
+
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
55 |
+
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
56 |
+
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
57 |
+
|
58 |
+
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
59 |
+
# ここより上で置き換えようとするとすごく大変
|
60 |
+
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
61 |
+
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
62 |
+
if input_ids.size()[0] != curr_batch_size[0]:
|
63 |
+
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
64 |
+
return input_ids
|
65 |
+
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
66 |
+
|
67 |
+
print(f"load images from {args.train_data_dir}")
|
68 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
69 |
+
print(f"found {len(image_paths)} images.")
|
70 |
+
|
71 |
+
# できればcacheに依存せず明示的にダウンロードしたい
|
72 |
+
print(f"loading GIT: {args.model_id}")
|
73 |
+
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
74 |
+
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
75 |
+
print("GIT loaded")
|
76 |
+
|
77 |
+
# captioningする
|
78 |
+
def run_batch(path_imgs):
|
79 |
+
imgs = [im for _, im in path_imgs]
|
80 |
+
|
81 |
+
curr_batch_size[0] = len(path_imgs)
|
82 |
+
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
83 |
+
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
84 |
+
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
85 |
+
|
86 |
+
if args.remove_words:
|
87 |
+
captions = remove_words(captions, args.debug)
|
88 |
+
|
89 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
90 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
91 |
+
f.write(caption + "\n")
|
92 |
+
if args.debug:
|
93 |
+
print(image_path, caption)
|
94 |
+
|
95 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
96 |
+
if args.max_data_loader_n_workers is not None:
|
97 |
+
dataset = train_util.ImageLoadingDataset(image_paths)
|
98 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
99 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
100 |
+
else:
|
101 |
+
data = [[(None, ip)] for ip in image_paths]
|
102 |
+
|
103 |
+
b_imgs = []
|
104 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
105 |
+
for data in data_entry:
|
106 |
+
if data is None:
|
107 |
+
continue
|
108 |
+
|
109 |
+
image, image_path = data
|
110 |
+
if image is None:
|
111 |
+
try:
|
112 |
+
image = Image.open(image_path)
|
113 |
+
if image.mode != 'RGB':
|
114 |
+
image = image.convert("RGB")
|
115 |
+
except Exception as e:
|
116 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
117 |
+
continue
|
118 |
+
|
119 |
+
b_imgs.append((image_path, image))
|
120 |
+
if len(b_imgs) >= args.batch_size:
|
121 |
+
run_batch(b_imgs)
|
122 |
+
b_imgs.clear()
|
123 |
+
|
124 |
+
if len(b_imgs) > 0:
|
125 |
+
run_batch(b_imgs)
|
126 |
+
|
127 |
+
print("done!")
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
parser = argparse.ArgumentParser()
|
132 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
133 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
134 |
+
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
|
135 |
+
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
|
136 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
137 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
138 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
139 |
+
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
140 |
+
parser.add_argument("--remove_words", action="store_true",
|
141 |
+
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
142 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
143 |
+
|
144 |
+
args = parser.parse_args()
|
145 |
+
main(args)
|
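GIT often hallucinates clauses like 'that says ...' for text it thinks it sees in the image, so --remove_words strips such phrases before the caption is written out. A small standalone example using one of the patterns above (the sample caption is invented):

import re

pattern = re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?')
caption = 'a girl holding a sign that says "hello world" on it'
print(pattern.sub("", caption).strip())
# -> 'a girl holding a sign'  (the script applies every pattern in PATTERN_REPLACE in turn)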
finetune/merge_captions_to_metadata.py
ADDED
@@ -0,0 +1,67 @@
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List
|
5 |
+
from tqdm import tqdm
|
6 |
+
import library.train_util as train_util
|
7 |
+
|
8 |
+
|
9 |
+
def main(args):
|
10 |
+
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
11 |
+
|
12 |
+
train_data_dir_path = Path(args.train_data_dir)
|
13 |
+
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
14 |
+
print(f"found {len(image_paths)} images.")
|
15 |
+
|
16 |
+
if args.in_json is None and Path(args.out_json).is_file():
|
17 |
+
args.in_json = args.out_json
|
18 |
+
|
19 |
+
if args.in_json is not None:
|
20 |
+
print(f"loading existing metadata: {args.in_json}")
|
21 |
+
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
22 |
+
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
23 |
+
else:
|
24 |
+
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
25 |
+
metadata = {}
|
26 |
+
|
27 |
+
print("merge caption texts to metadata json.")
|
28 |
+
for image_path in tqdm(image_paths):
|
29 |
+
caption_path = image_path.with_suffix(args.caption_extension)
|
30 |
+
caption = caption_path.read_text(encoding='utf-8').strip()
|
31 |
+
|
32 |
+
image_key = str(image_path) if args.full_path else image_path.stem
|
33 |
+
if image_key not in metadata:
|
34 |
+
metadata[image_key] = {}
|
35 |
+
|
36 |
+
metadata[image_key]['caption'] = caption
|
37 |
+
if args.debug:
|
38 |
+
print(image_key, caption)
|
39 |
+
|
40 |
+
# metadataを書き出して終わり
|
41 |
+
print(f"writing metadata: {args.out_json}")
|
42 |
+
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
43 |
+
print("done!")
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
49 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
50 |
+
parser.add_argument("--in_json", type=str,
|
51 |
+
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
52 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
53 |
+
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
54 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
55 |
+
parser.add_argument("--full_path", action="store_true",
|
56 |
+
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
57 |
+
parser.add_argument("--recursive", action="store_true",
|
58 |
+
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
59 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
60 |
+
|
61 |
+
args = parser.parse_args()
|
62 |
+
|
63 |
+
# スペルミスしていたオプションを復元する
|
64 |
+
if args.caption_extention is not None:
|
65 |
+
args.caption_extension = args.caption_extention
|
66 |
+
|
67 |
+
main(args)
|
finetune/merge_dd_tags_to_metadata.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List
|
5 |
+
from tqdm import tqdm
|
6 |
+
import library.train_util as train_util
|
7 |
+
|
8 |
+
|
9 |
+
def main(args):
|
10 |
+
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
11 |
+
|
12 |
+
train_data_dir_path = Path(args.train_data_dir)
|
13 |
+
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
14 |
+
print(f"found {len(image_paths)} images.")
|
15 |
+
|
16 |
+
if args.in_json is None and Path(args.out_json).is_file():
|
17 |
+
args.in_json = args.out_json
|
18 |
+
|
19 |
+
if args.in_json is not None:
|
20 |
+
print(f"loading existing metadata: {args.in_json}")
|
21 |
+
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
22 |
+
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
23 |
+
else:
|
24 |
+
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
25 |
+
metadata = {}
|
26 |
+
|
27 |
+
print("merge tags to metadata json.")
|
28 |
+
for image_path in tqdm(image_paths):
|
29 |
+
tags_path = image_path.with_suffix(args.caption_extension)
|
30 |
+
tags = tags_path.read_text(encoding='utf-8').strip()
|
31 |
+
|
32 |
+
image_key = str(image_path) if args.full_path else image_path.stem
|
33 |
+
if image_key not in metadata:
|
34 |
+
metadata[image_key] = {}
|
35 |
+
|
36 |
+
metadata[image_key]['tags'] = tags
|
37 |
+
if args.debug:
|
38 |
+
print(image_key, tags)
|
39 |
+
|
40 |
+
# metadataを書き出して終わり
|
41 |
+
print(f"writing metadata: {args.out_json}")
|
42 |
+
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
43 |
+
|
44 |
+
print("done!")
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == '__main__':
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
50 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
51 |
+
parser.add_argument("--in_json", type=str,
|
52 |
+
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
53 |
+
parser.add_argument("--full_path", action="store_true",
|
54 |
+
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
55 |
+
parser.add_argument("--recursive", action="store_true",
|
56 |
+
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
57 |
+
parser.add_argument("--caption_extension", type=str, default=".txt",
|
58 |
+
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
|
59 |
+
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
60 |
+
|
61 |
+
args = parser.parse_args()
|
62 |
+
main(args)
|
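Running merge_captions_to_metadata.py and then merge_dd_tags_to_metadata.py against the same out_json leaves each image key with both a caption and a tags field; prepare_buckets_latents.py later adds train_resolution to the same entries. A sketch of what a single entry might look like (values are invented):

import json

metadata = {
    "img_0001": {
        "caption": "a picture of a girl",
        "tags": "1girl, solo, long hair",
    }
}
print(json.dumps(metadata, indent=2))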
finetune/prepare_buckets_latents.py
ADDED
@@ -0,0 +1,261 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import cv2
|
9 |
+
import torch
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
import library.model_util as model_util
|
13 |
+
import library.train_util as train_util
|
14 |
+
|
15 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
IMAGE_TRANSFORMS = transforms.Compose(
|
18 |
+
[
|
19 |
+
transforms.ToTensor(),
|
20 |
+
transforms.Normalize([0.5], [0.5]),
|
21 |
+
]
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def collate_fn_remove_corrupted(batch):
|
26 |
+
"""Collate function that allows to remove corrupted examples in the
|
27 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
28 |
+
The 'None's in the batch are removed.
|
29 |
+
"""
|
30 |
+
# Filter out all the Nones (corrupted examples)
|
31 |
+
batch = list(filter(lambda x: x is not None, batch))
|
32 |
+
return batch
|
33 |
+
|
34 |
+
|
35 |
+
def get_latents(vae, images, weight_dtype):
|
36 |
+
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
37 |
+
img_tensors = torch.stack(img_tensors)
|
38 |
+
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
39 |
+
with torch.no_grad():
|
40 |
+
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
41 |
+
return latents
|
42 |
+
|
43 |
+
|
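get_latents produces latents at 1/8 of the image's spatial resolution, which is why train_resolution is later recorded rounded down to a multiple of 8. A small sketch of that relationship using the diffusers VAE directly (the checkpoint id and use of AutoencoderKL here are assumptions; the script itself loads the VAE via model_util.load_vae):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
vae.eval()

with torch.no_grad():
    batch = torch.randn(1, 3, 512, 512)  # the real script normalizes images to [-1, 1] first
    latents = vae.encode(batch).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 4, 64, 64]) -- 512 / 8 = 64 per side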
44 |
+
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
45 |
+
if is_full_path:
|
46 |
+
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
47 |
+
else:
|
48 |
+
base_name = image_key
|
49 |
+
if flip:
|
50 |
+
base_name += '_flip'
|
51 |
+
return os.path.join(data_dir, base_name)
|
52 |
+
|
53 |
+
|
54 |
+
def main(args):
|
55 |
+
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
56 |
+
if args.bucket_reso_steps % 8 > 0:
|
57 |
+
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
58 |
+
|
59 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
60 |
+
print(f"found {len(image_paths)} images.")
|
61 |
+
|
62 |
+
if os.path.exists(args.in_json):
|
63 |
+
print(f"loading existing metadata: {args.in_json}")
|
64 |
+
with open(args.in_json, "rt", encoding='utf-8') as f:
|
65 |
+
metadata = json.load(f)
|
66 |
+
else:
|
67 |
+
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
68 |
+
return
|
69 |
+
|
70 |
+
weight_dtype = torch.float32
|
71 |
+
if args.mixed_precision == "fp16":
|
72 |
+
weight_dtype = torch.float16
|
73 |
+
elif args.mixed_precision == "bf16":
|
74 |
+
weight_dtype = torch.bfloat16
|
75 |
+
|
76 |
+
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
77 |
+
vae.eval()
|
78 |
+
vae.to(DEVICE, dtype=weight_dtype)
|
79 |
+
|
80 |
+
# bucketのサイズを計算する
|
81 |
+
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
82 |
+
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
83 |
+
|
84 |
+
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
|
85 |
+
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
|
86 |
+
if not args.bucket_no_upscale:
|
87 |
+
bucket_manager.make_buckets()
|
88 |
+
else:
|
89 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
90 |
+
|
91 |
+
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
92 |
+
img_ar_errors = []
|
93 |
+
|
94 |
+
def process_batch(is_last):
|
95 |
+
for bucket in bucket_manager.buckets:
|
96 |
+
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
97 |
+
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
98 |
+
        assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
            f"latent shape {latents.shape}, {bucket[0][1].shape}"

        for (image_key, _), latent in zip(bucket, latents):
          npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
          np.savez(npz_file_name, latent)

        # flip
        if args.flip_aug:
          latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype)  # copyがないとTensor変換できない

          for (image_key, _), latent in zip(bucket, latents):
            npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
            np.savez(npz_file_name, latent)
        else:
          # remove existing flipped npz
          for image_key, _ in bucket:
            npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
            if os.path.isfile(npz_file_name):
              print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
              os.remove(npz_file_name)

        bucket.clear()

  # 読み込みの高速化のためにDataLoaderを使うオプション
  if args.max_data_loader_n_workers is not None:
    dataset = train_util.ImageLoadingDataset(image_paths)
    data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
  else:
    data = [[(None, ip)] for ip in image_paths]

  bucket_counts = {}
  for data_entry in tqdm(data, smoothing=0.0):
    if data_entry[0] is None:
      continue

    img_tensor, image_path = data_entry[0]
    if img_tensor is not None:
      image = transforms.functional.to_pil_image(img_tensor)
    else:
      try:
        image = Image.open(image_path)
        if image.mode != 'RGB':
          image = image.convert("RGB")
      except Exception as e:
        print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
        continue

    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
    if image_key not in metadata:
      metadata[image_key] = {}

    # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変

    reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
    img_ar_errors.append(abs(ar_error))
    bucket_counts[reso] = bucket_counts.get(reso, 0) + 1

    # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
    metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)

    if not args.bucket_no_upscale:
      # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
      assert resized_size[0] == reso[0] or resized_size[1] == reso[1], \
          f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
      assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], \
          f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"

    assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], \
        f"internal error resized size is small: {resized_size}, {reso}"

    # 既に存在するファイルがあればshapeを確認して同じならskipする
    if args.skip_existing:
      npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
      if args.flip_aug:
        npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")

      found = True
      for npz_file in npz_files:
        if not os.path.exists(npz_file):
          found = False
          break

        dat = np.load(npz_file)['arr_0']
        if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:  # latentsのshapeを確認
          found = False
          break
      if found:
        continue

    # 画像をリサイズしてトリミングする
    # PILにinter_areaがないのでcv2で……
    image = np.array(image)
    if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:  # リサイズ処理が必要?
      image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

    if resized_size[0] > reso[0]:
      trim_size = resized_size[0] - reso[0]
      image = image[:, trim_size//2:trim_size//2 + reso[0]]

    if resized_size[1] > reso[1]:
      trim_size = resized_size[1] - reso[1]
      image = image[trim_size//2:trim_size//2 + reso[1]]

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

    # # debug
    # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])

    # バッチへ追加
    bucket_manager.add_image(reso, (image_key, image))

    # バッチを推論するか判定して推論する
    process_batch(False)

  # 残りを処理する
  process_batch(True)

  bucket_manager.sort()
  for i, reso in enumerate(bucket_manager.resos):
    count = bucket_counts.get(reso, 0)
    if count > 0:
      print(f"bucket {i} {reso}: {count}")
  img_ar_errors = np.array(img_ar_errors)
  print(f"mean ar error: {np.mean(img_ar_errors)}")

  # metadataを書き出して終わり
  print(f"writing metadata: {args.out_json}")
  with open(args.out_json, "wt", encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
  print("done!")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
  parser.add_argument("--v2", action='store_true',
                      help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
  parser.add_argument("--max_resolution", type=str, default="512,512",
                      help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
  parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
  parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
  parser.add_argument("--bucket_reso_steps", type=int, default=64,
                      help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
  parser.add_argument("--bucket_no_upscale", action="store_true",
                      help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
  parser.add_argument("--mixed_precision", type=str, default="no",
                      choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
  parser.add_argument("--full_path", action="store_true",
                      help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  parser.add_argument("--flip_aug", action="store_true",
                      help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
  parser.add_argument("--skip_existing", action="store_true",
                      help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")

  args = parser.parse_args()
  main(args)
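The script above center-trims each resized image to its bucket resolution and checks cached latents against reso[1] // 8 and reso[0] // 8, since the Stable Diffusion VAE downsamples each spatial dimension by 8. Below is a minimal standalone sketch of that trim-and-check relationship; the helper names (center_trim, latent_matches) and the 4-channel latent are illustrative assumptions, not part of the script itself.

import numpy as np

DOWNSCALE = 8          # SD VAE reduces each spatial dimension by 8
LATENT_CHANNELS = 4    # assumed latent channel count, as in Stable Diffusion

def center_trim(image: np.ndarray, reso):
    # image is HxWxC, reso is (width, height); trim equally from both sides
    w, h = reso
    if image.shape[1] > w:
        off = (image.shape[1] - w) // 2
        image = image[:, off:off + w]
    if image.shape[0] > h:
        off = (image.shape[0] - h) // 2
        image = image[off:off + h]
    return image

def latent_matches(latent: np.ndarray, reso) -> bool:
    # a cached latent for bucket (w, h) must have shape (C, h // 8, w // 8)
    w, h = reso
    return latent.shape[1] == h // DOWNSCALE and latent.shape[2] == w // DOWNSCALE

img = np.zeros((540, 968, 3), dtype=np.uint8)   # resized image slightly larger than its bucket
trimmed = center_trim(img, (960, 512))
assert trimmed.shape[:2] == (512, 960)
assert latent_matches(np.zeros((LATENT_CHANNELS, 64, 120)), (960, 512))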
finetune/tag_images_by_wd14_tagger.py
ADDED
@@ -0,0 +1,200 @@
import argparse
import csv
import glob
import os

from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch

import library.train_util as train_util

# from wd14 tagger
IMAGE_SIZE = 448

# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]


def preprocess_image(image):
  image = np.array(image)
  image = image[:, :, ::-1]  # RGB->BGR

  # pad to square
  size = max(image.shape[0:2])
  pad_x = size - image.shape[1]
  pad_y = size - image.shape[0]
  pad_l = pad_x // 2
  pad_t = pad_y // 2
  image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)

  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
  image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

  image = image.astype(np.float32)
  return image


class ImageLoadingPrepDataset(torch.utils.data.Dataset):
  def __init__(self, image_paths):
    self.images = image_paths

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img_path = self.images[idx]

    try:
      image = Image.open(img_path).convert("RGB")
      image = preprocess_image(image)
      tensor = torch.tensor(image)
    except Exception as e:
      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
      return None

    return (tensor, img_path)


def collate_fn_remove_corrupted(batch):
  """Collate function that allows to remove corrupted examples in the
  dataloader. It expects that the dataloader returns 'None' when that occurs.
  The 'None's in the batch are removed.
  """
  # Filter out all the Nones (corrupted examples)
  batch = list(filter(lambda x: x is not None, batch))
  return batch


def main(args):
  # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
  # depreacatedの警告が出るけどなくなったらその時
  # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
  if not os.path.exists(args.model_dir) or args.force_download:
    print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
    for file in FILES:
      hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
    for file in SUB_DIR_FILES:
      hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
          args.model_dir, SUB_DIR), force_download=True, force_filename=file)
  else:
    print("using existing wd14 tagger model")

  # 画像を読み込む
  image_paths = train_util.glob_images(args.train_data_dir)
  print(f"found {len(image_paths)} images.")

  print("loading model and labels")
  model = load_model(args.model_dir)

  # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
  # 依存ライブラリを増やしたくないので自力で読むよ
  with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    l = [row for row in reader]
    header = l[0]             # tag_id,name,category,count
    rows = l[1:]
  assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"

  tags = [row[1] for row in rows[1:] if row[2] == '0']      # categoryが0、つまり通常のタグのみ

  # 推論する
  def run_batch(path_imgs):
    imgs = np.array([im for _, im in path_imgs])

    probs = model(imgs, training=False)
    probs = probs.numpy()

    for (image_path, _), prob in zip(path_imgs, probs):
      # 最初の4つはratingなので無視する
      # # First 4 labels are actually ratings: pick one with argmax
      # ratings_names = label_names[:4]
      # rating_index = ratings_names["probs"].argmax()
      # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]

      # それ以降はタグなのでconfidenceがthresholdより高いものを追加する
      # Everything else is tags: pick any where prediction confidence > threshold
      tag_text = ""
      for i, p in enumerate(prob[4:]):  # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
        if p >= args.thresh and i < len(tags):
          tag_text += ", " + tags[i]

      if len(tag_text) > 0:
        tag_text = tag_text[2:]   # 最初の ", " を消す

      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
        f.write(tag_text + '\n')
        if args.debug:
          print(image_path, tag_text)

  # 読み込みの高速化のためにDataLoaderを使うオプション
  if args.max_data_loader_n_workers is not None:
    dataset = ImageLoadingPrepDataset(image_paths)
    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
  else:
    data = [[(None, ip)] for ip in image_paths]

  b_imgs = []
  for data_entry in tqdm(data, smoothing=0.0):
    for data in data_entry:
      if data is None:
        continue

      image, image_path = data
      if image is not None:
        image = image.detach().numpy()
      else:
        try:
          image = Image.open(image_path)
          if image.mode != 'RGB':
            image = image.convert("RGB")
          image = preprocess_image(image)
        except Exception as e:
          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
          continue
      b_imgs.append((image_path, image))

      if len(b_imgs) >= args.batch_size:
        run_batch(b_imgs)
        b_imgs.clear()

  if len(b_imgs) > 0:
    run_batch(b_imgs)

  print("done!")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
                      help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
  parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
                      help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
  parser.add_argument("--force_download", action='store_true',
                      help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
  parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
  parser.add_argument("--caption_extention", type=str, default=None,
                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
  parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
  parser.add_argument("--debug", action="store_true", help="debug mode")

  args = parser.parse_args()

  # スペルミスしていたオプションを復元する
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

  main(args)
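In run_batch above, the first four model outputs are rating scores and everything after them lines up with the general-tag list; a tag is written only when its confidence reaches --thresh. A small sketch of that selection step on a dummy probability vector follows; the tag names and the select_tags helper are placeholders for illustration, not values from selected_tags.csv.

import numpy as np

def select_tags(prob: np.ndarray, tags: list, thresh: float = 0.35) -> str:
    # prob[:4] are rating scores and are skipped; prob[4:] correspond to the tag list
    picked = [tags[i] for i, p in enumerate(prob[4:]) if p >= thresh and i < len(tags)]
    return ", ".join(picked)

tags = ["1girl", "solo", "outdoors"]                 # placeholder tag names
prob = np.array([0.9, 0.05, 0.03, 0.02,              # ratings (ignored)
                 0.82, 0.10, 0.56])                  # general-tag confidences
print(select_tags(prob, tags))                       # -> "1girl, outdoors"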
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
 """
 
 import json
-from typing import List, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
 import glob
 import importlib
 import inspect
@@ -60,7 +60,6 @@ import math
 import os
 import random
 import re
-from typing import Any, Callable, List, Optional, Union
 
 import diffusers
 import numpy as np
@@ -81,6 +80,9 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
 import library.model_util as model_util
+import library.train_util as train_util
+import tools.original_control_net as original_control_net
+from tools.original_control_net import ControlNetInfo
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +489,9 @@ class PipelineLike():
     self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
     self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
 
+    # ControlNet
+    self.control_nets: List[ControlNetInfo] = []
+
   # Textual Inversion
   def add_token_replacement(self, target_token_id, rep_token_ids):
     self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +505,11 @@ class PipelineLike():
       new_tokens.append(token)
     return new_tokens
 
+  def set_control_nets(self, ctrl_nets):
+    self.control_nets = ctrl_nets
+
   # region xformersとか使う部分:独自に書き換えるので関係なし
+
   def enable_xformers_memory_efficient_attention(self):
     r"""
     Enable memory efficient attention as implemented in xformers.
@@ -581,6 +590,8 @@ class PipelineLike():
       latents: Optional[torch.FloatTensor] = None,
       max_embeddings_multiples: Optional[int] = 3,
       output_type: Optional[str] = "pil",
+      vae_batch_size: float = None,
+      return_latents: bool = False,
       # return_dict: bool = True,
       callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
       is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -672,6 +683,9 @@ class PipelineLike():
     else:
       raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+    vae_batch_size = batch_size if vae_batch_size is None else (
+        int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
+
     if strength < 0 or strength > 1:
       raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
@@ -752,7 +766,7 @@ class PipelineLike():
       text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
       text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # prompt複数件でもOK
 
-    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
      if isinstance(clip_guide_images, PIL.Image.Image):
        clip_guide_images = [clip_guide_images]
 
@@ -765,7 +779,7 @@ class PipelineLike():
         image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
         if len(image_embeddings_clip) == 1:
           image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-
+      elif self.vgg16_guidance_scale > 0:
         size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # とりあえず1/4に(小さいか?)
         clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
         clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -774,6 +788,10 @@ class PipelineLike():
         image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
         if len(image_embeddings_vgg16) == 1:
           image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+      else:
+        # ControlNetのhintにguide imageを流用する
+        # 前処理はControlNet側で行う
+        pass
 
     # set timesteps
     self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -781,7 +799,6 @@ class PipelineLike():
     latents_dtype = text_embeddings.dtype
     init_latents_orig = None
     mask = None
-    noise = None
 
     if init_image is None:
       # get the initial random noise unless the user supplied it
@@ -813,6 +830,8 @@ class PipelineLike():
       if isinstance(init_image[0], PIL.Image.Image):
        init_image = [preprocess_image(im) for im in init_image]
        init_image = torch.cat(init_image)
+      if isinstance(init_image, list):
+        init_image = torch.stack(init_image)
 
      # mask image to tensor
      if mask_image is not None:
@@ -823,9 +842,24 @@ class PipelineLike():
 
       # encode the init image into latents and scale the latents
       init_image = init_image.to(device=self.device, dtype=latents_dtype)
-
-
-
+      if init_image.size()[2:] == (height // 8, width // 8):
+        init_latents = init_image
+      else:
+        if vae_batch_size >= batch_size:
+          init_latent_dist = self.vae.encode(init_image).latent_dist
+          init_latents = init_latent_dist.sample(generator=generator)
+        else:
+          if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+          init_latents = []
+          for i in tqdm(range(0, batch_size, vae_batch_size)):
+            init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
+                                               if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
+            init_latents.append(init_latent_dist.sample(generator=generator))
+          init_latents = torch.cat(init_latents)
+
+      init_latents = 0.18215 * init_latents
+
       if len(init_latents) == 1:
         init_latents = init_latents.repeat((batch_size, 1, 1, 1))
       init_latents_orig = init_latents
@@ -864,12 +898,21 @@ class PipelineLike():
       extra_step_kwargs["eta"] = eta
 
     num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
+
+    if self.control_nets:
+      guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+
     for i, t in enumerate(tqdm(timesteps)):
       # expand the latents if we are doing classifier free guidance
       latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
       latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
       # predict the noise residual
-
+      if self.control_nets:
+        noise_pred = original_control_net.call_unet_and_control_net(
+            i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+      else:
+        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
       # perform guidance
       if do_classifier_free_guidance:
@@ -911,8 +954,19 @@ class PipelineLike():
       if is_cancelled_callback is not None and is_cancelled_callback():
         return None
 
+    if return_latents:
+      return (latents, False)
+
     latents = 1 / 0.18215 * latents
-
+    if vae_batch_size >= batch_size:
+      image = self.vae.decode(latents).sample
+    else:
+      if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+      images = []
+      for i in tqdm(range(0, batch_size, vae_batch_size)):
+        images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
+      image = torch.cat(images)
 
     image = (image / 2 + 0.5).clamp(0, 1)
 
@@ -1595,10 +1649,11 @@ def get_unweighted_text_embeddings(
       if pad == eos:  # v1
         text_input_chunk[:, -1] = text_input[0, -1]
       else:  # v2
-
-        text_input_chunk[
-
-        text_input_chunk[
+        for j in range(len(text_input_chunk)):
+          if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # 最後に普通の文字がある
+            text_input_chunk[j, -1] = eos
+          if text_input_chunk[j, 1] == pad:  # BOSだけであとはPAD
+            text_input_chunk[j, 1] = eos
 
       if clip_skip is None or clip_skip == 1:
         text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -1799,7 +1854,7 @@ def preprocess_mask(mask):
   mask = mask.convert("L")
   w, h = mask.size
   w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
+  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # LANCZOS)
   mask = np.array(mask).astype(np.float32) / 255.0
   mask = np.tile(mask, (4, 1, 1))
   mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
@@ -1817,6 +1872,35 @@ def preprocess_mask(mask):
 #   return text_encoder
 
 
+class BatchDataBase(NamedTuple):
+  # バッチ分割が必要ないデータ
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # バッチ分割が必要なデータ
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  return_latents: bool
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
 def main(args):
   if args.fp16:
     dtype = torch.float16
@@ -1881,10 +1965,7 @@ def main(args):
   # tokenizerを読み込む
   print("loading tokenizer")
   if use_stable_diffusion_format:
-
-    tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-  else:
-    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)
 
   # schedulerを用意する
   sched_init_args = {}
@@ -1995,11 +2076,13 @@ def main(args):
   # networkを組み込む
   if args.network_module:
     networks = []
+    network_default_muls = []
    for i, network_module in enumerate(args.network_module):
      print("import network module:", network_module)
      imported_module = importlib.import_module(network_module)
 
      network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)
 
      net_kwargs = {}
      if args.network_args and i < len(args.network_args):
@@ -2014,7 +2097,7 @@ def main(args):
      network_weight = args.network_weights[i]
      print("load network weights from:", network_weight)
 
-      if model_util.is_safetensors(network_weight):
+      if model_util.is_safetensors(network_weight) and args.network_show_meta:
        from safetensors.torch import safe_open
        with safe_open(network_weight, framework="pt") as f:
          metadata = f.metadata()
@@ -2037,6 +2120,18 @@ def main(args):
   else:
     networks = []
 
+  # ControlNetの処理
+  control_nets: List[ControlNetInfo] = []
+  if args.control_net_models:
+    for i, model in enumerate(args.control_net_models):
+      prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+      weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+      ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+      ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+      prep = original_control_net.load_preprocess(prep_type)
+      control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+
   if args.opt_channels_last:
     print(f"set optimizing: channels last")
     text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2145,14 @@ def main(args):
     if vgg16_model is not None:
       vgg16_model.to(memory_format=torch.channels_last)
 
+    for cn in control_nets:
+      cn.unet.to(memory_format=torch.channels_last)
+      cn.net.to(memory_format=torch.channels_last)
+
   pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                       clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                       vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+  pipe.set_control_nets(control_nets)
   print("pipeline is ready.")
 
   if args.diffusers_xformers:
@@ -2177,18 +2277,34 @@ def main(args):
     mask_images = l
 
   # 画像サイズにオプション指定があるときはリサイズする
-  if
-
-
+  if args.W is not None and args.H is not None:
+    if init_images is not None:
+      print(f"resize img2img source images to {args.W}*{args.H}")
+      init_images = resize_images(init_images, (args.W, args.H))
   if mask_images is not None:
     print(f"resize img2img mask images to {args.W}*{args.H}")
     mask_images = resize_images(mask_images, (args.W, args.H))
 
+  if networks and mask_images:
+    # mask を領域情報として流用する、現在は1枚だけ対応
+    # TODO 複数のnetwork classの混在時の考慮
+    print("use mask as region")
+    # import cv2
+    # for i in range(3):
+    #   cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
+    #   cv2.waitKey()
+    #   cv2.destroyAllWindows()
+    networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+    mask_images = None
+
   prev_image = None  # for VGG16 guided
   if args.guide_image_path is not None:
-    print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-    guide_images =
-
+    print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+    guide_images = []
+    for p in args.guide_image_path:
+      guide_images.extend(load_images(p))
+
+    print(f"loaded {len(guide_images)} guide images for guidance")
    if len(guide_images) == 0:
      print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
      guide_images = None
@@ -2219,33 +2335,46 @@ def main(args):
   iter_seed = random.randint(0, 0x7fffffff)
 
   # バッチ処理の関数
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
     batch_size = len(batch)
 
     # highres_fixの処理
     if highres_fix and not highres_1st:
-      # 1st stage
-      print("process 1st
+      # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
+      print("process 1st stage")
       batch_1st = []
-      for
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for _, base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
        width_1st = width_1st - width_1st % 32
        height_1st = height_1st - height_1st % 32
-
+
+        ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                               ext.negative_scale, ext.strength, ext.network_muls)
+        batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
      images_1st = process_batch(batch_1st, True, True)
 
      # 2nd stageのバッチを作成して以下処理する
-      print("process 2nd
+      print("process 2nd stage")
+      if args.highres_fix_latents_upscaling:
+        org_dtype = images_1st.dtype
+        if images_1st.dtype == torch.bfloat16:
+          images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
+        images_1st = torch.nn.functional.interpolate(
+            images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear')  # , antialias=True)
+        images_1st = images_1st.to(org_dtype)
+
      batch_2nd = []
-      for i, (
-
-
-
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        if not args.highres_fix_latents_upscaling:
+          image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        batch_2nd.append(bd_2nd)
      batch = batch_2nd
 
-
-
+    # このバッチの情報を取り出す
+    return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+        (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
    noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
 
    prompts = []
@@ -2278,7 +2407,7 @@ def main(args):
    all_images_are_same = True
    all_masks_are_same = True
    all_guide_images_are_same = True
-    for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
+    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
      prompts.append(prompt)
      negative_prompts.append(negative_prompt)
      seeds.append(seed)
@@ -2295,9 +2424,13 @@ def main(args):
          all_masks_are_same = mask_images[-2] is mask_image
 
      if guide_image is not None:
-
-
-        all_guide_images_are_same =
+        if type(guide_image) is list:
+          guide_images.extend(guide_image)
+          all_guide_images_are_same = False
+        else:
+          guide_images.append(guide_image)
+          if i > 0 and all_guide_images_are_same:
+            all_guide_images_are_same = guide_images[-2] is guide_image
 
      # make start code
      torch.manual_seed(seed)
@@ -2320,10 +2453,24 @@ def main(args):
    if guide_images is not None and all_guide_images_are_same:
      guide_images = guide_images[0]
 
+    # ControlNet使用時はguide imageをリサイズする
+    if control_nets:
+      # TODO resampleのメソッド
+      guide_images = guide_images if type(guide_images) == list else [guide_images]
+      guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+      if len(guide_images) == 1:
+        guide_images = guide_images[0]
+
    # generate
+    if networks:
+      for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+        n.set_multiplier(m)
+
    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
-                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
-
+                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
+                  vae_batch_size=args.vae_batch_size, return_latents=return_latents,
+                  clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
+    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
      return images
 
    # save image
@@ -2398,6 +2545,7 @@ def main(args):
      strength = 0.8 if args.strength is None else args.strength
      negative_prompt = ""
      clip_prompt = None
+      network_muls = None
 
      prompt_args = prompt.strip().split(' --')
      prompt = prompt_args[0]
@@ -2461,6 +2609,15 @@ def main(args):
            clip_prompt = m.group(1)
            print(f"clip prompt: {clip_prompt}")
            continue
+
+          m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+          if m:  # network multiplies
+            network_muls = [float(v) for v in m.group(1).split(",")]
+            while len(network_muls) < len(networks):
+              network_muls.append(network_muls[-1])
+            print(f"network mul: {network_muls}")
+            continue
+
        except ValueError as ex:
          print(f"Exception in parsing / 解析エラー: {parg}")
          print(ex)
@@ -2498,7 +2655,12 @@ def main(args):
      mask_image = mask_images[global_step % len(mask_images)]
 
    if guide_images is not None:
-
+      if control_nets:  # 複数件の場合あり
+        c = len(control_nets)
+        p = global_step % (len(guide_images) // c)
+        guide_image = guide_images[p * c:p * c + c]
+      else:
+        guide_image = guide_images[global_step % len(guide_images)]
    elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
      if prev_image is None:
        print("Generate 1st image without guide image.")
@@ -2506,10 +2668,9 @@ def main(args):
        print("Use previous image as guide image.")
        guide_image = prev_image
 
-
-
-
-    if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
+    b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                   BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
+    if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
      process_batch(batch_data, highres_fix)
      batch_data.clear()
 
@@ -2553,6 +2714,8 @@ if __name__ == '__main__':
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
+  parser.add_argument("--vae_batch_size", type=float, default=None,
+                      help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
  parser.add_argument('--sampler', type=str, default='ddim',
                      choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2564,6 +2727,8 @@ if __name__ == '__main__':
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
  parser.add_argument("--vae", type=str, default=None,
                      help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
  #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
  parser.add_argument("--seed", type=int, default=None,
@@ -2578,12 +2743,15 @@ if __name__ == '__main__':
  parser.add_argument("--opt_channels_last", action='store_true',
                      help='set channels last option to model / モデルにchannels lastを指定し最適化する')
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
                      help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                      help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,15 +2765,26 @@ if __name__ == '__main__':
                      help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                      help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-  parser.add_argument("--guide_image_path", type=str, default=None,
+  parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                      help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
  parser.add_argument("--highres_fix_scale", type=float, default=None,
                      help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
  parser.add_argument("--highres_fix_steps", type=int, default=28,
                      help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
  parser.add_argument("--highres_fix_save_1st", action='store_true',
                      help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
+                      help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
  parser.add_argument("--negative_scale", type=float, default=None,
                      help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
 
+  parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                      help='ControlNet models to use / 使用するControlNetのモデル名')
+  parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                      help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+  parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+  parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                      help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
+
  args = parser.parse_args()
  main(args)
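Two of the additions in this diff are easy to miss: --vae_batch_size accepts either an absolute batch size (>= 1) or a fraction of the main batch size, and the pending batch is flushed whenever the new prompt's BatchDataExt (size, steps, scales, strength, network multipliers) differs from the previous entry's. The sketch below restates both rules in isolation; the helper name resolve_vae_batch_size and the concrete values are chosen purely for illustration.

from typing import NamedTuple, Optional, Tuple

def resolve_vae_batch_size(batch_size: int, vae_batch_size: Optional[float]) -> int:
    # None -> same as the main batch size; >= 1 -> absolute size; < 1 -> ratio of the main batch size
    if vae_batch_size is None:
        return batch_size
    return int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))

class BatchDataExt(NamedTuple):
    width: int
    height: int
    steps: int
    scale: float
    negative_scale: Optional[float]
    strength: float
    network_muls: Optional[Tuple[float, ...]]

print(resolve_vae_batch_size(8, None))   # 8
print(resolve_vae_batch_size(8, 2))      # 2
print(resolve_vae_batch_size(8, 0.25))   # 2

a = BatchDataExt(512, 512, 50, 7.0, None, 0.8, None)
b = BatchDataExt(512, 768, 50, 7.0, None, 0.8, None)
print(a != b)   # True -> the pending batch would be processed before adding the new entry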
library/model_util.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
import math
|
5 |
import os
|
6 |
import torch
|
7 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
from safetensors.torch import load_file, save_file
|
10 |
|
@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
|
916 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
else:
|
918 |
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
|
|
|
|
919 |
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
920 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
print("loading text encoder:", info)
|
922 |
|
|
|
4 |
import math
|
5 |
import os
|
6 |
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
8 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
from safetensors.torch import load_file, save_file
|
10 |
|
|
|
916 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
else:
|
918 |
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
+
|
920 |
+
logging.set_verbosity_error() # don't show annoying warning
|
921 |
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
922 |
+
logging.set_verbosity_warning()
|
923 |
+
|
924 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
925 |
print("loading text encoder:", info)
|
926 |
|
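The change above wraps CLIPTextModel.from_pretrained between logging.set_verbosity_error() and logging.set_verbosity_warning(), so the warning transformers prints about newly initialized weights is hidden while the converted state dict loaded right afterwards overwrites those weights anyway. The same pattern can be packaged as a small context manager; this is a sketch with a hypothetical helper name, not code from the repository:

from contextlib import contextmanager
from transformers import logging

@contextmanager
def quiet_transformers():
    # Temporarily lower the transformers verbosity to ERROR, then restore
    # whatever level was active before (the diff restores WARNING directly).
    previous = logging.get_verbosity()
    logging.set_verbosity_error()
    try:
        yield
    finally:
        logging.set_verbosity(previous)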
library/train_util.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
|
|
4 |
import json
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
-
from torch.autograd.function import Function
|
10 |
import glob
|
11 |
import math
|
12 |
import os
|
@@ -17,10 +26,16 @@ from io import BytesIO
|
|
17 |
|
18 |
from tqdm import tqdm
|
19 |
import torch
|
|
|
20 |
from torchvision import transforms
|
21 |
from transformers import CLIPTokenizer
|
|
|
22 |
import diffusers
|
23 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
24 |
import albumentations as albu
|
25 |
import numpy as np
|
26 |
from PIL import Image
|
@@ -195,23 +210,95 @@ class BucketBatchIndex(NamedTuple):
|
|
195 |
batch_index: int
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
class BaseDataset(torch.utils.data.Dataset):
|
199 |
-
def __init__(self, tokenizer, max_token_length
|
200 |
super().__init__()
|
201 |
-
self.tokenizer
|
202 |
self.max_token_length = max_token_length
|
203 |
-
self.shuffle_caption = shuffle_caption
|
204 |
-
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
# width/height is used when enable_bucket==False
|
206 |
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
-
self.face_crop_aug_range = face_crop_aug_range
|
208 |
-
self.flip_aug = flip_aug
|
209 |
-
self.color_aug = color_aug
|
210 |
self.debug_dataset = debug_dataset
|
211 |
-
|
|
|
|
|
212 |
self.token_padding_disabled = False
|
213 |
-
self.dataset_dirs_info = {}
|
214 |
-
self.reg_dataset_dirs_info = {}
|
215 |
self.tag_frequency = {}
|
216 |
|
217 |
self.enable_bucket = False
|
@@ -225,49 +312,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
225 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
|
227 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
-
self.dropout_rate: float = 0
|
229 |
-
self.dropout_every_n_epochs: int = None
|
230 |
-
self.tag_dropout_rate: float = 0
|
231 |
|
232 |
# augmentation
|
233 |
-
|
234 |
-
if color_aug:
|
235 |
-
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
-
self.aug = albu.Compose([
|
237 |
-
albu.OneOf([
|
238 |
-
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
-
albu.RandomGamma((95, 105), p=.5),
|
240 |
-
], p=.33),
|
241 |
-
albu.HorizontalFlip(p=flip_p)
|
242 |
-
], p=1.)
|
243 |
-
elif flip_aug:
|
244 |
-
self.aug = albu.Compose([
|
245 |
-
albu.HorizontalFlip(p=flip_p)
|
246 |
-
], p=1.)
|
247 |
-
else:
|
248 |
-
self.aug = None
|
249 |
|
250 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
|
252 |
self.image_data: Dict[str, ImageInfo] = {}
|
|
|
253 |
|
254 |
self.replacements = {}
|
255 |
|
256 |
def set_current_epoch(self, epoch):
|
257 |
self.current_epoch = epoch
|
258 |
-
|
259 |
-
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
-
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
-
self.dropout_rate = dropout_rate
|
262 |
-
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
-
self.tag_dropout_rate = tag_dropout_rate
|
264 |
|
265 |
def set_tag_frequency(self, dir_name, captions):
|
266 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
for caption in captions:
|
269 |
for tag in caption.split(","):
|
270 |
-
|
|
|
271 |
tag = tag.lower()
|
272 |
frequency = frequency_for_dir.get(tag, 0)
|
273 |
frequency_for_dir[tag] = frequency + 1
|
@@ -278,42 +344,36 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
278 |
def add_replacement(self, str_from, str_to):
|
279 |
self.replacements[str_from] = str_to
|
280 |
|
281 |
-
def process_caption(self, caption):
|
282 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
-
is_drop_out =
|
284 |
-
is_drop_out = is_drop_out or
|
285 |
|
286 |
if is_drop_out:
|
287 |
caption = ""
|
288 |
else:
|
289 |
-
if
|
290 |
def dropout_tags(tokens):
|
291 |
-
if
|
292 |
return tokens
|
293 |
l = []
|
294 |
for token in tokens:
|
295 |
-
if random.random() >=
|
296 |
l.append(token)
|
297 |
return l
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
tokens = dropout_tags(tokens)
|
305 |
-
else:
|
306 |
-
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
-
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
-
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
|
310 |
-
|
311 |
-
|
312 |
|
313 |
-
|
314 |
|
315 |
-
|
316 |
-
caption = ", ".join(tokens)
|
317 |
|
318 |
# textual inversion対応
|
319 |
for str_from, str_to in self.replacements.items():
|
@@ -367,8 +427,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
367 |
input_ids = torch.stack(iids_list) # 3,77
|
368 |
return input_ids
|
369 |
|
370 |
-
def register_image(self, info: ImageInfo):
|
371 |
self.image_data[info.image_key] = info
|
|
|
372 |
|
373 |
def make_buckets(self):
|
374 |
'''
|
@@ -467,7 +528,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
467 |
img = np.array(image, np.uint8)
|
468 |
return img
|
469 |
|
470 |
-
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
image_height, image_width = image.shape[0:2]
|
472 |
|
473 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
@@ -477,22 +538,27 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
477 |
image_height, image_width = image.shape[0:2]
|
478 |
if image_width > reso[0]:
|
479 |
trim_size = image_width - reso[0]
|
480 |
-
p = trim_size // 2 if not
|
481 |
# print("w", trim_size, p)
|
482 |
image = image[:, p:p + reso[0]]
|
483 |
if image_height > reso[1]:
|
484 |
trim_size = image_height - reso[1]
|
485 |
-
p = trim_size // 2 if not
|
486 |
# print("h", trim_size, p)
|
487 |
image = image[p:p + reso[1]]
|
488 |
|
489 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
return image
|
491 |
|
|
|
|
|
|
|
492 |
def cache_latents(self, vae):
|
493 |
# TODO ここを高速化したい
|
494 |
print("caching latents.")
|
495 |
for info in tqdm(self.image_data.values()):
|
|
|
|
|
496 |
if info.latents_npz is not None:
|
497 |
info.latents = self.load_latents_from_npz(info, False)
|
498 |
info.latents = torch.FloatTensor(info.latents)
|
@@ -502,13 +568,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
502 |
continue
|
503 |
|
504 |
image = self.load_image(info.absolute_path)
|
505 |
-
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
|
507 |
img_tensor = self.image_transforms(image)
|
508 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
|
511 |
-
if
|
512 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
img_tensor = self.image_transforms(image)
|
514 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
@@ -518,11 +584,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
518 |
image = Image.open(image_path)
|
519 |
return image.size
|
520 |
|
521 |
-
def load_image_with_face_info(self, image_path: str):
|
522 |
img = self.load_image(image_path)
|
523 |
|
524 |
face_cx = face_cy = face_w = face_h = 0
|
525 |
-
if
|
526 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
if len(tokens) >= 5:
|
528 |
face_cx = int(tokens[-4])
|
@@ -533,7 +599,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
533 |
return img, face_cx, face_cy, face_w, face_h
|
534 |
|
535 |
# いい感じに切り出す
|
536 |
-
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
height, width = image.shape[0:2]
|
538 |
if height == self.height and width == self.width:
|
539 |
return image
|
@@ -541,8 +607,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
541 |
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
face_size = max(face_w, face_h)
|
543 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
-
min_scale = min(1.0, max(min_scale, self.size / (face_size *
|
545 |
-
max_scale = min(1.0, max(min_scale, self.size / (face_size *
|
546 |
if min_scale >= max_scale: # range指定がmin==max
|
547 |
scale = min_scale
|
548 |
else:
|
@@ -560,13 +626,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
560 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
|
563 |
-
if
|
564 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
else:
|
568 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
-
if
|
570 |
if face_size > self.size // 10 and face_size >= 40:
|
571 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
|
@@ -589,9 +655,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
589 |
return self._length
|
590 |
|
591 |
def __getitem__(self, index):
|
592 |
-
if index == 0:
|
593 |
-
self.shuffle_buckets()
|
594 |
-
|
595 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
@@ -604,28 +667,29 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
604 |
|
605 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
image_info = self.image_data[image_key]
|
|
|
607 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
|
609 |
# image/latentsを処理する
|
610 |
if image_info.latents is not None:
|
611 |
-
latents = image_info.latents if not
|
612 |
image = None
|
613 |
elif image_info.latents_npz is not None:
|
614 |
-
latents = self.load_latents_from_npz(image_info,
|
615 |
latents = torch.FloatTensor(latents)
|
616 |
image = None
|
617 |
else:
|
618 |
# 画像を読み込み、必要ならcropする
|
619 |
-
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
im_h, im_w = img.shape[0:2]
|
621 |
|
622 |
if self.enable_bucket:
|
623 |
-
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
else:
|
625 |
if face_cx > 0: # 顔位置情報あり
|
626 |
-
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
elif im_h > self.height or im_w > self.width:
|
628 |
-
assert
|
629 |
if im_h > self.height:
|
630 |
p = random.randint(0, im_h - self.height)
|
631 |
img = img[p:p + self.height]
|
@@ -637,8 +701,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
637 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
|
639 |
# augmentation
|
640 |
-
|
641 |
-
|
|
|
642 |
|
643 |
latents = None
|
644 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
@@ -646,7 +711,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
646 |
images.append(image)
|
647 |
latents_list.append(latents)
|
648 |
|
649 |
-
caption = self.process_caption(image_info.caption)
|
650 |
captions.append(caption)
|
651 |
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
input_ids_list.append(self.get_input_ids(caption))
|
@@ -677,9 +742,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
677 |
|
678 |
|
679 |
class DreamBoothDataset(BaseDataset):
|
680 |
-
def __init__(self,
|
681 |
-
super().__init__(tokenizer, max_token_length,
|
682 |
-
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
|
684 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
|
@@ -702,7 +766,7 @@ class DreamBoothDataset(BaseDataset):
|
|
702 |
self.bucket_reso_steps = None # この情報は使われない
|
703 |
self.bucket_no_upscale = False
|
704 |
|
705 |
-
def read_caption(img_path):
|
706 |
# captionの候補ファイル名を作る
|
707 |
base_name = os.path.splitext(img_path)[0]
|
708 |
base_name_face_det = base_name
|
@@ -725,153 +789,181 @@ class DreamBoothDataset(BaseDataset):
|
|
725 |
break
|
726 |
return caption
|
727 |
|
728 |
-
def load_dreambooth_dir(
|
729 |
-
if not os.path.isdir(
|
730 |
-
|
731 |
-
return
|
732 |
|
733 |
-
|
734 |
-
|
735 |
-
n_repeats = int(tokens[0])
|
736 |
-
except ValueError as e:
|
737 |
-
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
-
return 0, [], []
|
739 |
-
|
740 |
-
caption_by_folder = '_'.join(tokens[1:])
|
741 |
-
img_paths = glob_images(dir, "*")
|
742 |
-
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
|
744 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
captions = []
|
746 |
for img_path in img_paths:
|
747 |
-
cap_for_img = read_caption(img_path)
|
748 |
-
|
|
|
|
|
|
|
|
|
749 |
|
750 |
-
self.set_tag_frequency(os.path.basename(
|
751 |
|
752 |
-
return
|
753 |
|
754 |
-
print("prepare
|
755 |
-
train_dirs = os.listdir(train_data_dir)
|
756 |
num_train_images = 0
|
757 |
-
|
758 |
-
|
759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
|
761 |
for img_path, caption in zip(img_paths, captions):
|
762 |
-
info = ImageInfo(img_path,
|
763 |
-
|
|
|
|
|
|
|
764 |
|
765 |
-
|
|
|
766 |
|
767 |
print(f"{num_train_images} train images with repeating.")
|
768 |
self.num_train_images = num_train_images
|
769 |
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
print("prepare reg images.")
|
774 |
-
reg_infos: List[ImageInfo] = []
|
775 |
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
780 |
|
781 |
-
|
782 |
-
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
-
reg_infos.append(info)
|
784 |
|
785 |
-
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
|
787 |
-
|
788 |
-
|
789 |
-
|
|
|
|
|
|
|
|
|
|
|
790 |
|
791 |
-
|
792 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
else:
|
794 |
-
|
795 |
-
n = 0
|
796 |
-
first_loop = True
|
797 |
-
while n < num_train_images:
|
798 |
-
for info in reg_infos:
|
799 |
-
if first_loop:
|
800 |
-
self.register_image(info)
|
801 |
-
n += info.num_repeats
|
802 |
-
else:
|
803 |
-
info.num_repeats += 1
|
804 |
-
n += 1
|
805 |
-
if n >= num_train_images:
|
806 |
-
break
|
807 |
-
first_loop = False
|
808 |
|
809 |
-
|
|
|
|
|
810 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
print(f"loading existing metadata: {json_file_name}")
|
820 |
-
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
-
metadata = json.load(f)
|
822 |
-
else:
|
823 |
-
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
|
825 |
-
|
826 |
-
|
827 |
-
self.batch_size = batch_size
|
828 |
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
-
abs_path = abs_path[0]
|
839 |
-
|
840 |
-
caption = img_md.get('caption')
|
841 |
-
tags = img_md.get('tags')
|
842 |
-
if caption is None:
|
843 |
-
caption = tags
|
844 |
-
elif tags is not None and len(tags) > 0:
|
845 |
-
caption = caption + ', ' + tags
|
846 |
-
tags_list.append(tags)
|
847 |
-
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
-
|
849 |
-
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
-
image_info.image_size = img_md.get('train_resolution')
|
851 |
-
|
852 |
-
if not self.color_aug and not self.random_crop:
|
853 |
-
# if npz exists, use them
|
854 |
-
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
-
|
856 |
-
self.register_image(image_info)
|
857 |
-
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
-
self.num_reg_images = 0
|
859 |
|
860 |
-
|
861 |
-
|
862 |
-
|
|
|
|
|
|
|
863 |
|
864 |
# check existence of all npz files
|
865 |
-
use_npz_latents = not (
|
866 |
if use_npz_latents:
|
|
|
867 |
npz_any = False
|
868 |
npz_all = True
|
|
|
869 |
for image_info in self.image_data.values():
|
|
|
|
|
870 |
has_npz = image_info.latents_npz is not None
|
871 |
npz_any = npz_any or has_npz
|
872 |
|
873 |
-
if
|
874 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
|
|
875 |
npz_all = npz_all and has_npz
|
876 |
|
877 |
if npz_any and not npz_all:
|
@@ -883,7 +975,7 @@ class FineTuningDataset(BaseDataset):
|
|
883 |
elif not npz_all:
|
884 |
use_npz_latents = False
|
885 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
-
if
|
887 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
# else:
|
889 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
@@ -929,7 +1021,7 @@ class FineTuningDataset(BaseDataset):
|
|
929 |
for image_info in self.image_data.values():
|
930 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
|
932 |
-
def image_key_to_npz_file(self, image_key):
|
933 |
base_name = os.path.splitext(image_key)[0]
|
934 |
npz_file_norm = base_name + '.npz'
|
935 |
|
@@ -941,8 +1033,8 @@ class FineTuningDataset(BaseDataset):
|
|
941 |
return npz_file_norm, npz_file_flip
|
942 |
|
943 |
# image_key is relative path
|
944 |
-
npz_file_norm = os.path.join(
|
945 |
-
npz_file_flip = os.path.join(
|
946 |
|
947 |
if not os.path.exists(npz_file_norm):
|
948 |
npz_file_norm = None
|
@@ -953,13 +1045,60 @@ class FineTuningDataset(BaseDataset):
|
|
953 |
return npz_file_norm, npz_file_flip
|
954 |
|
955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
|
960 |
train_dataset.set_current_epoch(1)
|
961 |
k = 0
|
962 |
-
|
|
|
|
|
|
|
963 |
if example['latents'] is not None:
|
964 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
@@ -1364,6 +1503,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|
1364 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
|
|
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
|
1369 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
@@ -1387,10 +1555,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1387 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
-
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
-
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
-
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
-
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
parser.add_argument("--xformers", action="store_true",
|
@@ -1398,7 +1562,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1398 |
parser.add_argument("--vae", type=str, default=None,
|
1399 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
|
1401 |
-
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
@@ -1419,15 +1582,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1419 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
-
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
-
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
-
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
-
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
parser.add_argument("--lowram", action="store_true",
|
1429 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1431 |
if support_dreambooth:
|
1432 |
# DreamBooth training
|
1433 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
@@ -1449,8 +1620,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1449 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
-
parser.add_argument("--keep_tokens", type=int, default=
|
1453 |
-
help="keep heading N tokens when shuffling caption tokens / caption
|
1454 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
@@ -1475,11 +1646,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1475 |
if support_caption_dropout:
|
1476 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
-
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
-
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
|
1481 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
-
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
|
1485 |
if support_dreambooth:
|
@@ -1504,16 +1675,256 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|
1504 |
# region utils
|
1505 |
|
1506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1507 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
# backward compatibility
|
1509 |
if args.caption_extention is not None:
|
1510 |
args.caption_extension = args.caption_extention
|
1511 |
args.caption_extention = None
|
1512 |
|
1513 |
-
if args.cache_latents:
|
1514 |
-
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
-
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
-
|
1517 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
if args.resolution is not None:
|
1519 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
@@ -1536,12 +1947,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|
1536 |
|
1537 |
def load_tokenizer(args: argparse.Namespace):
|
1538 |
print("prepare tokenizer")
|
1539 |
-
if args.v2
|
1540 |
-
|
1541 |
-
|
1542 |
-
|
1543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1544 |
print(f"update token length: {args.max_token_length}")
|
|
|
|
|
|
|
|
|
|
|
1545 |
return tokenizer
|
1546 |
|
1547 |
|
@@ -1592,13 +2019,19 @@ def prepare_dtype(args: argparse.Namespace):
|
|
1592 |
|
1593 |
|
1594 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
-
|
|
|
|
|
1596 |
if load_stable_diffusion_format:
|
1597 |
print("load StableDiffusion checkpoint")
|
1598 |
-
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
|
1599 |
else:
|
1600 |
print("load Diffusers pretrained models")
|
1601 |
-
|
|
|
|
|
|
|
|
|
1602 |
text_encoder = pipe.text_encoder
|
1603 |
vae = pipe.vae
|
1604 |
unet = pipe.unet
|
@@ -1767,6 +2200,197 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
1767 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1770 |
# endregion
|
1771 |
|
1772 |
# region 前処理用
|
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
4 |
+
import importlib
|
5 |
import json
|
6 |
+
import re
|
7 |
import shutil
|
8 |
import time
|
9 |
+
from typing import (
|
10 |
+
Dict,
|
11 |
+
List,
|
12 |
+
NamedTuple,
|
13 |
+
Optional,
|
14 |
+
Sequence,
|
15 |
+
Tuple,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
from accelerate import Accelerator
|
|
|
19 |
import glob
|
20 |
import math
|
21 |
import os
|
|
|
26 |
|
27 |
from tqdm import tqdm
|
28 |
import torch
|
29 |
+
from torch.optim import Optimizer
|
30 |
from torchvision import transforms
|
31 |
from transformers import CLIPTokenizer
|
32 |
+
import transformers
|
33 |
import diffusers
|
34 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
35 |
+
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
36 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
37 |
+
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
38 |
+
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
39 |
import albumentations as albu
|
40 |
import numpy as np
|
41 |
from PIL import Image
|
|
|
210 |
batch_index: int
|
211 |
|
212 |
|
213 |
+
class AugHelper:
|
214 |
+
def __init__(self):
|
215 |
+
# prepare all possible augmentators
|
216 |
+
color_aug_method = albu.OneOf([
|
217 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
218 |
+
albu.RandomGamma((95, 105), p=.5),
|
219 |
+
], p=.33)
|
220 |
+
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
221 |
+
|
222 |
+
# key: (use_color_aug, use_flip_aug)
|
223 |
+
self.augmentors = {
|
224 |
+
(True, True): albu.Compose([
|
225 |
+
color_aug_method,
|
226 |
+
flip_aug_method,
|
227 |
+
], p=1.),
|
228 |
+
(True, False): albu.Compose([
|
229 |
+
color_aug_method,
|
230 |
+
], p=1.),
|
231 |
+
(False, True): albu.Compose([
|
232 |
+
flip_aug_method,
|
233 |
+
], p=1.),
|
234 |
+
(False, False): None
|
235 |
+
}
|
236 |
+
|
237 |
+
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
238 |
+
return self.augmentors[(use_color_aug, use_flip_aug)]
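# Usage sketch for AugHelper (not part of the module, assumes numpy is installed):
# each subset asks for the augmentor matching its (color_aug, flip_aug) flags,
# and (False, False) deliberately maps to None so no augmentation work is done.
import numpy as np
helper = AugHelper()
aug = helper.get_augmentor(use_color_aug=False, use_flip_aug=True)
sample = np.zeros((512, 512, 3), dtype=np.uint8)  # hypothetical image
if aug is not None:
    sample = aug(image=sample)["image"]  # albumentations pipelines return a dict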
|
239 |
+
|
240 |
+
|
241 |
+
class BaseSubset:
|
242 |
+
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
243 |
+
self.image_dir = image_dir
|
244 |
+
self.num_repeats = num_repeats
|
245 |
+
self.shuffle_caption = shuffle_caption
|
246 |
+
self.keep_tokens = keep_tokens
|
247 |
+
self.color_aug = color_aug
|
248 |
+
self.flip_aug = flip_aug
|
249 |
+
self.face_crop_aug_range = face_crop_aug_range
|
250 |
+
self.random_crop = random_crop
|
251 |
+
self.caption_dropout_rate = caption_dropout_rate
|
252 |
+
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
253 |
+
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
254 |
+
|
255 |
+
self.img_count = 0
|
256 |
+
|
257 |
+
|
258 |
+
class DreamBoothSubset(BaseSubset):
|
259 |
+
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
260 |
+
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
261 |
+
|
262 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
263 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
264 |
+
|
265 |
+
self.is_reg = is_reg
|
266 |
+
self.class_tokens = class_tokens
|
267 |
+
self.caption_extension = caption_extension
|
268 |
+
|
269 |
+
def __eq__(self, other) -> bool:
|
270 |
+
if not isinstance(other, DreamBoothSubset):
|
271 |
+
return NotImplemented
|
272 |
+
return self.image_dir == other.image_dir
|
273 |
+
|
274 |
+
|
275 |
+
class FineTuningSubset(BaseSubset):
|
276 |
+
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
277 |
+
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
278 |
+
|
279 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
280 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
281 |
+
|
282 |
+
self.metadata_file = metadata_file
|
283 |
+
|
284 |
+
def __eq__(self, other) -> bool:
|
285 |
+
if not isinstance(other, FineTuningSubset):
|
286 |
+
return NotImplemented
|
287 |
+
return self.metadata_file == other.metadata_file
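# Hypothetical construction examples (argument order follows the signatures above;
# the directory names, metadata file and class token are invented for illustration):
db_subset = DreamBoothSubset("train_data/10_mychar", False, "mychar", ".caption",
                             10, True, 1, False, False, None, False, 0.0, 0, 0.0)
ft_subset = FineTuningSubset("train_data", "meta_lat.json",
                             1, True, 0, False, False, None, False, 0.0, 0, 0.0)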
|
288 |
+
|
289 |
+
|
290 |
class BaseDataset(torch.utils.data.Dataset):
|
291 |
+
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
292 |
super().__init__()
|
293 |
+
self.tokenizer = tokenizer
|
294 |
self.max_token_length = max_token_length
|
|
|
|
|
295 |
# width/height is used when enable_bucket==False
|
296 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
297 |
self.debug_dataset = debug_dataset
|
298 |
+
|
299 |
+
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
300 |
+
|
301 |
self.token_padding_disabled = False
|
|
|
|
|
302 |
self.tag_frequency = {}
|
303 |
|
304 |
self.enable_bucket = False
|
|
|
312 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
313 |
|
314 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
315 |
|
316 |
# augmentation
|
317 |
+
self.aug_helper = AugHelper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
320 |
|
321 |
self.image_data: Dict[str, ImageInfo] = {}
|
322 |
+
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
323 |
|
324 |
self.replacements = {}
|
325 |
|
326 |
def set_current_epoch(self, epoch):
|
327 |
self.current_epoch = epoch
|
328 |
+
self.shuffle_buckets()
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
def set_tag_frequency(self, dir_name, captions):
|
331 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
332 |
self.tag_frequency[dir_name] = frequency_for_dir
|
333 |
for caption in captions:
|
334 |
for tag in caption.split(","):
|
335 |
+
tag = tag.strip()
|
336 |
+
if tag:
|
337 |
tag = tag.lower()
|
338 |
frequency = frequency_for_dir.get(tag, 0)
|
339 |
frequency_for_dir[tag] = frequency + 1
|
|
|
344 |
def add_replacement(self, str_from, str_to):
|
345 |
self.replacements[str_from] = str_to
|
346 |
|
347 |
+
def process_caption(self, subset: BaseSubset, caption):
|
348 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
349 |
+
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
350 |
+
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
351 |
|
352 |
if is_drop_out:
|
353 |
caption = ""
|
354 |
else:
|
355 |
+
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
356 |
def dropout_tags(tokens):
|
357 |
+
if subset.caption_tag_dropout_rate <= 0:
|
358 |
return tokens
|
359 |
l = []
|
360 |
for token in tokens:
|
361 |
+
if random.random() >= subset.caption_tag_dropout_rate:
|
362 |
l.append(token)
|
363 |
return l
|
364 |
|
365 |
+
fixed_tokens = []
|
366 |
+
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
367 |
+
if subset.keep_tokens > 0:
|
368 |
+
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
369 |
+
flex_tokens = flex_tokens[subset.keep_tokens:]
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
+
if subset.shuffle_caption:
|
372 |
+
random.shuffle(flex_tokens)
|
373 |
|
374 |
+
flex_tokens = dropout_tags(flex_tokens)
|
375 |
|
376 |
+
caption = ", ".join(fixed_tokens + flex_tokens)
|
|
|
377 |
|
378 |
# textual inversion対応
|
379 |
for str_from, str_to in self.replacements.items():
|
|
|
427 |
input_ids = torch.stack(iids_list) # 3,77
|
428 |
return input_ids
|
429 |
|
430 |
+
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
431 |
self.image_data[info.image_key] = info
|
432 |
+
self.image_to_subset[info.image_key] = subset
|
433 |
|
434 |
def make_buckets(self):
|
435 |
'''
|
|
|
528 |
img = np.array(image, np.uint8)
|
529 |
return img
|
530 |
|
531 |
+
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
532 |
image_height, image_width = image.shape[0:2]
|
533 |
|
534 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
538 |
image_height, image_width = image.shape[0:2]
|
539 |
if image_width > reso[0]:
|
540 |
trim_size = image_width - reso[0]
|
541 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
542 |
# print("w", trim_size, p)
|
543 |
image = image[:, p:p + reso[0]]
|
544 |
if image_height > reso[1]:
|
545 |
trim_size = image_height - reso[1]
|
546 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
547 |
# print("h", trim_size, p)
|
548 |
image = image[p:p + reso[1]]
|
549 |
|
550 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
551 |
return image
|
552 |
|
553 |
+
def is_latent_cacheable(self):
|
554 |
+
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
555 |
+
|
556 |
def cache_latents(self, vae):
|
557 |
# TODO ここを高速化したい
|
558 |
print("caching latents.")
|
559 |
for info in tqdm(self.image_data.values()):
|
560 |
+
subset = self.image_to_subset[info.image_key]
|
561 |
+
|
562 |
if info.latents_npz is not None:
|
563 |
info.latents = self.load_latents_from_npz(info, False)
|
564 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
568 |
continue
|
569 |
|
570 |
image = self.load_image(info.absolute_path)
|
571 |
+
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
572 |
|
573 |
img_tensor = self.image_transforms(image)
|
574 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
575 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
576 |
|
577 |
+
if subset.flip_aug:
|
578 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
579 |
img_tensor = self.image_transforms(image)
|
580 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
584 |
image = Image.open(image_path)
|
585 |
return image.size
|
586 |
|
587 |
+
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
588 |
img = self.load_image(image_path)
|
589 |
|
590 |
face_cx = face_cy = face_w = face_h = 0
|
591 |
+
if subset.face_crop_aug_range is not None:
|
592 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
593 |
if len(tokens) >= 5:
|
594 |
face_cx = int(tokens[-4])
|
|
|
599 |
return img, face_cx, face_cy, face_w, face_h
|
600 |
|
601 |
# いい感じに切り出す
|
602 |
+
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
|
603 |
height, width = image.shape[0:2]
|
604 |
if height == self.height and width == self.width:
|
605 |
return image
|
|
|
607 |
# 画像サイズはsizeより大きいのでリサイズする
|
608 |
face_size = max(face_w, face_h)
|
609 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
610 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
611 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
612 |
if min_scale >= max_scale: # range指定がmin==max
|
613 |
scale = min_scale
|
614 |
else:
|
|
|
626 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
627 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
628 |
|
629 |
+
if subset.random_crop:
|
630 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
631 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
632 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
633 |
else:
|
634 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
635 |
+
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
636 |
if face_size > self.size // 10 and face_size >= 40:
|
637 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
638 |
|
|
|
655 |
return self._length
|
656 |
|
657 |
def __getitem__(self, index):
|
|
|
|
|
|
|
658 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
659 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
660 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
667 |
|
668 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
669 |
image_info = self.image_data[image_key]
|
670 |
+
subset = self.image_to_subset[image_key]
|
671 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
672 |
|
673 |
# image/latentsを処理する
|
674 |
if image_info.latents is not None:
|
675 |
+
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
|
676 |
image = None
|
677 |
elif image_info.latents_npz is not None:
|
678 |
+
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
|
679 |
latents = torch.FloatTensor(latents)
|
680 |
image = None
|
681 |
else:
|
682 |
# 画像を読み込み、必要ならcropする
|
683 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
684 |
im_h, im_w = img.shape[0:2]
|
685 |
|
686 |
if self.enable_bucket:
|
687 |
+
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
688 |
else:
|
689 |
if face_cx > 0: # 顔位置情報あり
|
690 |
+
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
691 |
elif im_h > self.height or im_w > self.width:
|
692 |
+
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
693 |
if im_h > self.height:
|
694 |
p = random.randint(0, im_h - self.height)
|
695 |
img = img[p:p + self.height]
|
|
|
701 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
702 |
|
703 |
# augmentation
|
704 |
+
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
705 |
+
if aug is not None:
|
706 |
+
img = aug(image=img)['image']
|
707 |
|
708 |
latents = None
|
709 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
711 |
images.append(image)
|
712 |
latents_list.append(latents)
|
713 |
|
714 |
+
caption = self.process_caption(subset, image_info.caption)
|
715 |
captions.append(caption)
|
716 |
if not self.token_padding_disabled: # this option might be omitted in future
|
717 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
742 |
|
743 |
|
744 |
class DreamBoothDataset(BaseDataset):
|
745 |
+
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
|
746 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
|
|
747 |
|
748 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
749 |
|
|
|
766 |
self.bucket_reso_steps = None # この情報は使われない
|
767 |
self.bucket_no_upscale = False
|
768 |
|
769 |
+
def read_caption(img_path, caption_extension):
|
770 |
# captionの候補ファイル名を作る
|
771 |
base_name = os.path.splitext(img_path)[0]
|
772 |
base_name_face_det = base_name
|
|
|
789 |
break
|
790 |
return caption
|
791 |
|
792 |
+
def load_dreambooth_dir(subset: DreamBoothSubset):
|
793 |
+
if not os.path.isdir(subset.image_dir):
|
794 |
+
print(f"not directory: {subset.image_dir}")
|
795 |
+
return [], []
|
796 |
|
797 |
+
img_paths = glob_images(subset.image_dir, "*")
|
798 |
+
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
|
800 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
801 |
captions = []
|
802 |
for img_path in img_paths:
|
803 |
+
cap_for_img = read_caption(img_path, subset.caption_extension)
|
804 |
+
if cap_for_img is None and subset.class_tokens is None:
|
805 |
+
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
806 |
+
captions.append("")
|
807 |
+
else:
|
808 |
+
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
809 |
|
810 |
+
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
811 |
|
812 |
+
return img_paths, captions
|
813 |
|
814 |
+
print("prepare images.")
|
|
|
815 |
num_train_images = 0
|
816 |
+
num_reg_images = 0
|
817 |
+
reg_infos: List[ImageInfo] = []
|
818 |
+
for subset in subsets:
|
819 |
+
if subset.num_repeats < 1:
|
820 |
+
print(
|
821 |
+
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
822 |
+
continue
|
823 |
+
|
824 |
+
if subset in self.subsets:
|
825 |
+
print(
|
826 |
+
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
827 |
+
continue
|
828 |
+
|
829 |
+
img_paths, captions = load_dreambooth_dir(subset)
|
830 |
+
if len(img_paths) < 1:
|
831 |
+
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
832 |
+
continue
|
833 |
+
|
834 |
+
if subset.is_reg:
|
835 |
+
num_reg_images += subset.num_repeats * len(img_paths)
|
836 |
+
else:
|
837 |
+
num_train_images += subset.num_repeats * len(img_paths)
|
838 |
|
839 |
for img_path, caption in zip(img_paths, captions):
|
840 |
+
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
841 |
+
if subset.is_reg:
|
842 |
+
reg_infos.append(info)
|
843 |
+
else:
|
844 |
+
self.register_image(info, subset)
|
845 |
|
846 |
+
subset.img_count = len(img_paths)
|
847 |
+
self.subsets.append(subset)
|
848 |
|
849 |
print(f"{num_train_images} train images with repeating.")
|
850 |
self.num_train_images = num_train_images
|
851 |
|
852 |
+
print(f"{num_reg_images} reg images.")
|
853 |
+
if num_train_images < num_reg_images:
|
854 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
|
|
|
|
855 |
|
856 |
+
if num_reg_images == 0:
|
857 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
858 |
+
else:
|
859 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
860 |
+
n = 0
|
861 |
+
first_loop = True
|
862 |
+
while n < num_train_images:
|
863 |
+
for info in reg_infos:
|
864 |
+
if first_loop:
|
865 |
+
self.register_image(info, subset)
|
866 |
+
n += info.num_repeats
|
867 |
+
else:
|
868 |
+
info.num_repeats += 1
|
869 |
+
n += 1
|
870 |
+
if n >= num_train_images:
|
871 |
+
break
|
872 |
+
first_loop = False
|
873 |
|
874 |
+
self.num_reg_images = num_reg_images
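# Worked example of the balancing above (hypothetical counts): with 400 train
# images after repeats and 100 reg images at num_repeats=1, the first pass
# registers all 100 reg images (n=100); later passes bump num_repeats on
# individual reg images one step at a time until n reaches 400, so reg images
# are reused roughly evenly instead of some never being seen.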
|
|
|
|
|
875 |
|
|
|
876 |
|
877 |
+
class FineTuningDataset(BaseDataset):
|
878 |
+
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
879 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
880 |
+
|
881 |
+
self.batch_size = batch_size
|
882 |
+
|
883 |
+
self.num_train_images = 0
|
884 |
+
self.num_reg_images = 0
|
885 |
|
886 |
+
for subset in subsets:
|
887 |
+
if subset.num_repeats < 1:
|
888 |
+
print(
|
889 |
+
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
890 |
+
continue
|
891 |
+
|
892 |
+
if subset in self.subsets:
|
893 |
+
print(
|
894 |
+
f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
895 |
+
continue
|
896 |
+
|
897 |
+
# メタデータを読み込む
|
898 |
+
if os.path.exists(subset.metadata_file):
|
899 |
+
print(f"loading existing metadata: {subset.metadata_file}")
|
900 |
+
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
901 |
+
metadata = json.load(f)
|
902 |
else:
|
903 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
904 |
|
905 |
+
if len(metadata) < 1:
|
906 |
+
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
907 |
+
continue
|
908 |
|
909 |
+
tags_list = []
|
910 |
+
for image_key, img_md in metadata.items():
|
911 |
+
# path情報を作る
|
912 |
+
if os.path.exists(image_key):
|
913 |
+
abs_path = image_key
|
914 |
+
else:
|
915 |
+
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
916 |
+
if os.path.exists(npz_path):
|
917 |
+
abs_path = npz_path
|
918 |
+
else:
|
919 |
+
# わりといい加減だがいい方法が思いつかん
|
920 |
+
abs_path = glob_images(subset.image_dir, image_key)
|
921 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
922 |
+
abs_path = abs_path[0]
|
923 |
|
924 |
+
caption = img_md.get('caption')
|
925 |
+
tags = img_md.get('tags')
|
926 |
+
if caption is None:
|
927 |
+
caption = tags
|
928 |
+
elif tags is not None and len(tags) > 0:
|
929 |
+
caption = caption + ', ' + tags
|
930 |
+
tags_list.append(tags)
|
|
|
|
|
|
|
|
|
|
|
931 |
|
932 |
+
if caption is None:
|
933 |
+
caption = ""
|
|
|
934 |
|
935 |
+
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
936 |
+
image_info.image_size = img_md.get('train_resolution')
|
937 |
+
|
938 |
+
if not subset.color_aug and not subset.random_crop:
|
939 |
+
# if npz exists, use them
|
940 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
941 |
+
|
942 |
+
self.register_image(image_info, subset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
943 |
|
944 |
+
self.num_train_images += len(metadata) * subset.num_repeats
|
945 |
+
|
946 |
+
# TODO do not record tag freq when no tag
|
947 |
+
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
|
948 |
+
subset.img_count = len(metadata)
|
949 |
+
self.subsets.append(subset)
|
950 |
|
951 |
# check existence of all npz files
|
952 |
+
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
953 |
if use_npz_latents:
|
954 |
+
flip_aug_in_subset = False
|
955 |
npz_any = False
|
956 |
npz_all = True
|
957 |
+
|
958 |
for image_info in self.image_data.values():
|
959 |
+
subset = self.image_to_subset[image_info.image_key]
|
960 |
+
|
961 |
has_npz = image_info.latents_npz is not None
|
962 |
npz_any = npz_any or has_npz
|
963 |
|
964 |
+
if subset.flip_aug:
|
965 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
966 |
+
flip_aug_in_subset = True
|
967 |
npz_all = npz_all and has_npz
|
968 |
|
969 |
if npz_any and not npz_all:
|
|
|
975 |
elif not npz_all:
|
976 |
use_npz_latents = False
|
977 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
978 |
+
if flip_aug_in_subset:
|
979 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
980 |
# else:
|
981 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
|
|
1021 |
for image_info in self.image_data.values():
|
1022 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
1023 |
|
1024 |
+
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
1025 |
base_name = os.path.splitext(image_key)[0]
|
1026 |
npz_file_norm = base_name + '.npz'
|
1027 |
|
|
|
1033 |
return npz_file_norm, npz_file_flip
|
1034 |
|
1035 |
# image_key is relative path
|
1036 |
+
npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
|
1037 |
+
npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
|
1038 |
|
1039 |
if not os.path.exists(npz_file_norm):
|
1040 |
npz_file_norm = None
|
|
|
1045 |
return npz_file_norm, npz_file_flip
|
1046 |
|
1047 |
|
1048 |
+
# behave as Dataset mock
|
1049 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1050 |
+
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
1051 |
+
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
1052 |
+
|
1053 |
+
super().__init__(datasets)
|
1054 |
+
|
1055 |
+
self.image_data = {}
|
1056 |
+
self.num_train_images = 0
|
1057 |
+
self.num_reg_images = 0
|
1058 |
+
|
1059 |
+
# simply concat together
|
1060 |
+
# TODO: handling image_data key duplication among dataset
|
1061 |
+
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
1062 |
+
for dataset in datasets:
|
1063 |
+
self.image_data.update(dataset.image_data)
|
1064 |
+
self.num_train_images += dataset.num_train_images
|
1065 |
+
self.num_reg_images += dataset.num_reg_images
|
1066 |
+
|
1067 |
+
def add_replacement(self, str_from, str_to):
|
1068 |
+
for dataset in self.datasets:
|
1069 |
+
dataset.add_replacement(str_from, str_to)
|
1070 |
+
|
1071 |
+
# def make_buckets(self):
|
1072 |
+
# for dataset in self.datasets:
|
1073 |
+
# dataset.make_buckets()
|
1074 |
+
|
1075 |
+
def cache_latents(self, vae):
|
1076 |
+
for i, dataset in enumerate(self.datasets):
|
1077 |
+
print(f"[Dataset {i}]")
|
1078 |
+
dataset.cache_latents(vae)
|
1079 |
+
|
1080 |
+
def is_latent_cacheable(self) -> bool:
|
1081 |
+
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
1082 |
+
|
1083 |
+
def set_current_epoch(self, epoch):
|
1084 |
+
for dataset in self.datasets:
|
1085 |
+
dataset.set_current_epoch(epoch)
|
1086 |
+
|
1087 |
+
def disable_token_padding(self):
|
1088 |
+
for dataset in self.datasets:
|
1089 |
+
dataset.disable_token_padding()
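# Usage sketch with hypothetical dataset objects: DatasetGroup lets training code
# keep calling the same methods it used on a single dataset. Note that make_buckets
# is commented out above, so callers still invoke it on each child dataset themselves.
#   group = DatasetGroup([dreambooth_dataset, finetuning_dataset])
#   group.cache_latents(vae)         # delegates to every child, printing "[Dataset i]"
#   group.set_current_epoch(epoch)   # keeps per-dataset bucket shuffling in sync
#   len(group)                       # torch ConcatDataset sums the child lengths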
|
1090 |
+
|
1091 |
+
|
1092 |
def debug_dataset(train_dataset, show_input_ids=False):
|
1093 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
1094 |
print("Escape for exit. / Escキーで中断、終了します")
|
1095 |
|
1096 |
train_dataset.set_current_epoch(1)
|
1097 |
k = 0
|
1098 |
+
indices = list(range(len(train_dataset)))
|
1099 |
+
random.shuffle(indices)
|
1100 |
+
for i, idx in enumerate(indices):
|
1101 |
+
example = train_dataset[idx]
|
1102 |
if example['latents'] is not None:
|
1103 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
1104 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
|
|
1503 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1504 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1505 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1506 |
+
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
1507 |
+
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
1508 |
+
|
1509 |
+
|
1510 |
+
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
1511 |
+
parser.add_argument("--optimizer_type", type=str, default="",
|
1512 |
+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
|
1513 |
+
|
1514 |
+
# backward compatibility
|
1515 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1516 |
+
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1517 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1518 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1519 |
+
|
1520 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1521 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
1522 |
+
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
|
1523 |
+
|
1524 |
+
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
|
1525 |
+
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
1526 |
+
|
1527 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1528 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
|
1529 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1530 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1531 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
1532 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
1533 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
1534 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
1535 |
|
1536 |
|
1537 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
|
|
1555 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1556 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1557 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
|
|
|
|
|
|
|
|
1558 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1559 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1560 |
parser.add_argument("--xformers", action="store_true",
|
|
|
1562 |
parser.add_argument("--vae", type=str, default=None,
|
1563 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1564 |
|
|
|
1565 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1566 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1567 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
|
|
1582 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1583 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1584 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
|
|
|
|
|
|
|
|
1585 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1586 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1587 |
parser.add_argument("--lowram", action="store_true",
|
1588 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1589 |
|
1590 |
+
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
1591 |
+
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
1592 |
+
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
1593 |
+
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
1594 |
+
parser.add_argument("--sample_prompts", type=str, default=None,
|
1595 |
+
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
1596 |
+
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
1597 |
+
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
1598 |
+
'dpmsolver++', 'dpmsingle',
|
1599 |
+
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
1600 |
+
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
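For reference, --sample_prompts points to a plain text file that sample_images() below parses line by line; lines starting with # are skipped, and per-prompt options are introduced by " --" (w = width, h = height, d = seed, s = steps, l = CFG scale, n = negative prompt). An illustrative file:
# one prompt per line
masterpiece, best quality, 1girl, standing --n lowres, bad anatomy --w 512 --h 768 --d 1 --l 7.5 --s 28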
1601 |
+
|
1602 |
if support_dreambooth:
|
1603 |
# DreamBooth training
|
1604 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
|
|
1620 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1621 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1622 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1623 |
+
parser.add_argument("--keep_tokens", type=int, default=0,
|
1624 |
+
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
|
1625 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1626 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1627 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
|
|
1646 |
if support_caption_dropout:
|
1647 |
# Textual Inversion does not support caption dropout
|
1648 |
# prefixed with "caption" to avoid confusion with the usual tensor Dropout; every_n_epochs defaults to None to stay consistent with the other options
|
1649 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
|
1650 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1651 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
|
1652 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1653 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
|
1654 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1655 |
|
1656 |
if support_dreambooth:
|
|
|
1675 |
# region utils
|
1676 |
|
1677 |
|
1678 |
+
def get_optimizer(args, trainable_params):
|
1679 |
+
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
1680 |
+
|
1681 |
+
optimizer_type = args.optimizer_type
|
1682 |
+
if args.use_8bit_adam:
|
1683 |
+
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
|
1684 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
|
1685 |
+
optimizer_type = "AdamW8bit"
|
1686 |
+
|
1687 |
+
elif args.use_lion_optimizer:
|
1688 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
|
1689 |
+
optimizer_type = "Lion"
|
1690 |
+
|
1691 |
+
if optimizer_type is None or optimizer_type == "":
|
1692 |
+
optimizer_type = "AdamW"
|
1693 |
+
optimizer_type = optimizer_type.lower()
|
1694 |
+
|
1695 |
+
# parse the arguments: only bool, float and comma-separated tuple values are supported
|
1696 |
+
optimizer_kwargs = {}
|
1697 |
+
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
1698 |
+
for arg in args.optimizer_args:
|
1699 |
+
key, value = arg.split('=')
|
1700 |
+
|
1701 |
+
value = value.split(",")
|
1702 |
+
for i in range(len(value)):
|
1703 |
+
if value[i].lower() == "true" or value[i].lower() == "false":
|
1704 |
+
value[i] = (value[i].lower() == "true")
|
1705 |
+
else:
|
1706 |
+
value[i] = float(value[i])
|
1707 |
+
if len(value) == 1:
|
1708 |
+
value = value[0]
|
1709 |
+
else:
|
1710 |
+
value = tuple(value)
|
1711 |
+
|
1712 |
+
optimizer_kwargs[key] = value
|
1713 |
+
# print("optkwargs:", optimizer_kwargs)
|
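To illustrate the parsing above (only bool, float and comma-separated tuples are handled), a hypothetical invocation and the kwargs it produces:
# --optimizer_args "weight_decay=0.01" "betas=0.9,0.999" "amsgrad=true"
# yields:
# optimizer_kwargs == {"weight_decay": 0.01, "betas": (0.9, 0.999), "amsgrad": True}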
1714 |
+
|
1715 |
+
lr = args.learning_rate
|
1716 |
+
|
1717 |
+
if optimizer_type == "AdamW8bit".lower():
|
1718 |
+
try:
|
1719 |
+
import bitsandbytes as bnb
|
1720 |
+
except ImportError:
|
1721 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1722 |
+
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
1723 |
+
optimizer_class = bnb.optim.AdamW8bit
|
1724 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1725 |
+
|
1726 |
+
elif optimizer_type == "SGDNesterov8bit".lower():
|
1727 |
+
try:
|
1728 |
+
import bitsandbytes as bnb
|
1729 |
+
except ImportError:
|
1730 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1731 |
+
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1732 |
+
if "momentum" not in optimizer_kwargs:
|
1733 |
+
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1734 |
+
optimizer_kwargs["momentum"] = 0.9
|
1735 |
+
|
1736 |
+
optimizer_class = bnb.optim.SGD8bit
|
1737 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1738 |
+
|
1739 |
+
elif optimizer_type == "Lion".lower():
|
1740 |
+
try:
|
1741 |
+
import lion_pytorch
|
1742 |
+
except ImportError:
|
1743 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
1744 |
+
print(f"use Lion optimizer | {optimizer_kwargs}")
|
1745 |
+
optimizer_class = lion_pytorch.Lion
|
1746 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1747 |
+
|
1748 |
+
elif optimizer_type == "SGDNesterov".lower():
|
1749 |
+
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1750 |
+
if "momentum" not in optimizer_kwargs:
|
1751 |
+
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1752 |
+
optimizer_kwargs["momentum"] = 0.9
|
1753 |
+
|
1754 |
+
optimizer_class = torch.optim.SGD
|
1755 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1756 |
+
|
1757 |
+
elif optimizer_type == "DAdaptation".lower():
|
1758 |
+
try:
|
1759 |
+
import dadaptation
|
1760 |
+
except ImportError:
|
1761 |
+
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
1762 |
+
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
1763 |
+
|
1764 |
+
actual_lr = lr
|
1765 |
+
lr_count = 1
|
1766 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1767 |
+
lrs = set()
|
1768 |
+
actual_lr = trainable_params[0].get("lr", actual_lr)
|
1769 |
+
for group in trainable_params:
|
1770 |
+
lrs.add(group.get("lr", actual_lr))
|
1771 |
+
lr_count = len(lrs)
|
1772 |
+
|
1773 |
+
if actual_lr <= 0.1:
|
1774 |
+
print(
|
1775 |
+
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
|
1776 |
+
print('recommend option: lr=1.0 / 推奨は1.0です')
|
1777 |
+
if lr_count > 1:
|
1778 |
+
print(
|
1779 |
+
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
|
1780 |
+
|
1781 |
+
optimizer_class = dadaptation.DAdaptAdam
|
1782 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1783 |
+
|
1784 |
+
elif optimizer_type == "Adafactor".lower():
|
1785 |
+
# check the arguments and correct them as needed
|
1786 |
+
if "relative_step" not in optimizer_kwargs:
|
1787 |
+
optimizer_kwargs["relative_step"] = True # default
|
1788 |
+
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
1789 |
+
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
|
1790 |
+
optimizer_kwargs["relative_step"] = True
|
1791 |
+
print(f"use Adafactor optimizer | {optimizer_kwargs}")
|
1792 |
+
|
1793 |
+
if optimizer_kwargs["relative_step"]:
|
1794 |
+
print(f"relative_step is true / relative_stepがtrueです")
|
1795 |
+
if lr != 0.0:
|
1796 |
+
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
1797 |
+
args.learning_rate = None
|
1798 |
+
|
1799 |
+
# when trainable_params are param groups: remove lr from each group
|
1800 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1801 |
+
has_group_lr = False
|
1802 |
+
for group in trainable_params:
|
1803 |
+
p = group.pop("lr", None)
|
1804 |
+
has_group_lr = has_group_lr or (p is not None)
|
1805 |
+
|
1806 |
+
if has_group_lr:
|
1807 |
+
# invalidate the args just in case; TODO: this reverses the dependency direction, which is not ideal
|
1808 |
+
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
|
1809 |
+
args.unet_lr = None
|
1810 |
+
args.text_encoder_lr = None
|
1811 |
+
|
1812 |
+
if args.lr_scheduler != "adafactor":
|
1813 |
+
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
1814 |
+
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
1815 |
+
|
1816 |
+
lr = None
|
1817 |
+
else:
|
1818 |
+
if args.max_grad_norm != 0.0:
|
1819 |
+
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
|
1820 |
+
if args.lr_scheduler != "constant_with_warmup":
|
1821 |
+
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
1822 |
+
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
1823 |
+
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
1824 |
+
|
1825 |
+
optimizer_class = transformers.optimization.Adafactor
|
1826 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1827 |
+
|
1828 |
+
elif optimizer_type == "AdamW".lower():
|
1829 |
+
print(f"use AdamW optimizer | {optimizer_kwargs}")
|
1830 |
+
optimizer_class = torch.optim.AdamW
|
1831 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1832 |
+
|
1833 |
+
else:
|
1834 |
+
# use an arbitrary optimizer specified by module path
|
1835 |
+
optimizer_type = args.optimizer_type  # the original, non-lowercased value (a bit awkward)
|
1836 |
+
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
1837 |
+
if "." not in optimizer_type:
|
1838 |
+
optimizer_module = torch.optim
|
1839 |
+
else:
|
1840 |
+
values = optimizer_type.split(".")
|
1841 |
+
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
1842 |
+
optimizer_type = values[-1]
|
1843 |
+
|
1844 |
+
optimizer_class = getattr(optimizer_module, optimizer_type)
|
1845 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1846 |
+
|
1847 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
1848 |
+
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
1849 |
+
|
1850 |
+
return optimizer_name, optimizer_args, optimizer
|
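A hedged sketch of how a training script is expected to call this helper (note that for Adafactor with relative_step the function may rewrite args.learning_rate and args.lr_scheduler as shown above):
# trainable_params can be a flat parameter list or a list of param groups with per-group "lr"
optimizer_name, optimizer_args_str, optimizer = get_optimizer(args, trainable_params)
print(f"optimizer: {optimizer_name} | {optimizer_args_str}")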
1851 |
+
|
1852 |
+
|
1853 |
+
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
|
1854 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
1855 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
1856 |
+
# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and integrated into sd-scripts
|
1857 |
+
|
1858 |
+
|
1859 |
+
def get_scheduler_fix(
|
1860 |
+
name: Union[str, SchedulerType],
|
1861 |
+
optimizer: Optimizer,
|
1862 |
+
num_warmup_steps: Optional[int] = None,
|
1863 |
+
num_training_steps: Optional[int] = None,
|
1864 |
+
num_cycles: int = 1,
|
1865 |
+
power: float = 1.0,
|
1866 |
+
):
|
1867 |
+
"""
|
1868 |
+
Unified API to get any scheduler from its name.
|
1869 |
+
Args:
|
1870 |
+
name (`str` or `SchedulerType`):
|
1871 |
+
The name of the scheduler to use.
|
1872 |
+
optimizer (`torch.optim.Optimizer`):
|
1873 |
+
The optimizer that will be used during training.
|
1874 |
+
num_warmup_steps (`int`, *optional*):
|
1875 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
1876 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1877 |
+
num_training_steps (`int`, *optional*):
|
1878 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
1879 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1880 |
+
num_cycles (`int`, *optional*):
|
1881 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
1882 |
+
power (`float`, *optional*, defaults to 1.0):
|
1883 |
+
Power factor. See `POLYNOMIAL` scheduler
|
1884 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
1885 |
+
The index of the last epoch when resuming training.
|
1886 |
+
"""
|
1887 |
+
if name.startswith("adafactor"):
|
1888 |
+
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
1889 |
+
initial_lr = float(name.split(':')[1])
|
1890 |
+
# print("adafactor scheduler init lr", initial_lr)
|
1891 |
+
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
1892 |
+
|
1893 |
+
name = SchedulerType(name)
|
1894 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
1895 |
+
if name == SchedulerType.CONSTANT:
|
1896 |
+
return schedule_func(optimizer)
|
1897 |
+
|
1898 |
+
# All other schedulers require `num_warmup_steps`
|
1899 |
+
if num_warmup_steps is None:
|
1900 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
1901 |
+
|
1902 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
1903 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
1904 |
+
|
1905 |
+
# All other schedulers require `num_training_steps`
|
1906 |
+
if num_training_steps is None:
|
1907 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
1908 |
+
|
1909 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
1910 |
+
return schedule_func(
|
1911 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
1912 |
+
)
|
1913 |
+
|
1914 |
+
if name == SchedulerType.POLYNOMIAL:
|
1915 |
+
return schedule_func(
|
1916 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
1917 |
+
)
|
1918 |
+
|
1919 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
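A usage sketch matching the signature above (argument values taken from the options added in add_optimizer_arguments; the "adafactor:<initial_lr>" name is produced by get_optimizer):
lr_scheduler = get_scheduler_fix(
    args.lr_scheduler, optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=args.max_train_steps,
    num_cycles=args.lr_scheduler_num_cycles,
    power=args.lr_scheduler_power)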
1920 |
+
|
1921 |
+
|
1922 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1923 |
# backward compatibility
|
1924 |
if args.caption_extention is not None:
|
1925 |
args.caption_extension = args.caption_extention
|
1926 |
args.caption_extention = None
|
1927 |
|
|
|
|
|
|
|
|
|
1928 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1929 |
if args.resolution is not None:
|
1930 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
|
|
1947 |
|
1948 |
def load_tokenizer(args: argparse.Namespace):
|
1949 |
print("prepare tokenizer")
|
1950 |
+
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
1951 |
+
|
1952 |
+
tokenizer: CLIPTokenizer = None
|
1953 |
+
if args.tokenizer_cache_dir:
|
1954 |
+
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
|
1955 |
+
if os.path.exists(local_tokenizer_path):
|
1956 |
+
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
1957 |
+
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
1958 |
+
|
1959 |
+
if tokenizer is None:
|
1960 |
+
if args.v2:
|
1961 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
1962 |
+
else:
|
1963 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
1964 |
+
|
1965 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
1966 |
print(f"update token length: {args.max_token_length}")
|
1967 |
+
|
1968 |
+
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
1969 |
+
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
1970 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
1971 |
+
|
1972 |
return tokenizer
|
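The cache is just a save_pretrained() copy stored under a directory name derived from the repo id, so offline use looks roughly like this (the path is illustrative):
# the first (online) run downloads the tokenizer and writes it to the cache directory;
# later runs load it from disk without any network access
args.tokenizer_cache_dir = "tokenizer_cache"   # hypothetical path
tokenizer = load_tokenizer(args)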
1973 |
|
1974 |
|
|
|
2019 |
|
2020 |
|
2021 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
2022 |
+
name_or_path = args.pretrained_model_name_or_path
|
2023 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
2024 |
+
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
2025 |
if load_stable_diffusion_format:
|
2026 |
print("load StableDiffusion checkpoint")
|
2027 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
|
2028 |
else:
|
2029 |
print("load Diffusers pretrained models")
|
2030 |
+
try:
|
2031 |
+
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
2032 |
+
except EnvironmentError as ex:
|
2033 |
+
print(
|
2034 |
+
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
2035 |
text_encoder = pipe.text_encoder
|
2036 |
vae = pipe.vae
|
2037 |
unet = pipe.unet
|
|
|
2200 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
2201 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
2202 |
|
2203 |
+
|
2204 |
+
# scheduler:
|
2205 |
+
SCHEDULER_LINEAR_START = 0.00085
|
2206 |
+
SCHEDULER_LINEAR_END = 0.0120
|
2207 |
+
SCHEDULER_TIMESTEPS = 1000
|
2208 |
+
SCHEDLER_SCHEDULE = 'scaled_linear'
|
2209 |
+
|
2210 |
+
|
2211 |
+
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
2212 |
+
"""
|
2213 |
+
The default Diffusers pipeline is used for generation, so prompt weighting is not supported.
|
2214 |
+
clip skip is supported.
|
2215 |
+
"""
|
2216 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
2217 |
+
return
|
2218 |
+
if args.sample_every_n_epochs is not None:
|
2219 |
+
# sample_every_n_steps is ignored in this case
|
2220 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
2221 |
+
return
|
2222 |
+
else:
|
2223 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
2224 |
+
return
|
2225 |
+
|
2226 |
+
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
2227 |
+
if not os.path.isfile(args.sample_prompts):
|
2228 |
+
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
2229 |
+
return
|
2230 |
+
|
2231 |
+
org_vae_device = vae.device  # should be on the CPU
|
2232 |
+
vae.to(device)
|
2233 |
+
|
2234 |
+
# create a wrapper to support clip skip
|
2235 |
+
if args.clip_skip is None:
|
2236 |
+
text_encoder_or_wrapper = text_encoder
|
2237 |
+
else:
|
2238 |
+
class Wrapper():
|
2239 |
+
def __init__(self, tenc) -> None:
|
2240 |
+
self.tenc = tenc
|
2241 |
+
self.config = {}
|
2242 |
+
super().__init__()
|
2243 |
+
|
2244 |
+
def __call__(self, input_ids, attention_mask):
|
2245 |
+
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
2246 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
2247 |
+
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
2248 |
+
pooled_output = enc_out['pooler_output']
|
2249 |
+
return encoder_hidden_states, pooled_output  # only the 1st output is used
|
2250 |
+
|
2251 |
+
text_encoder_or_wrapper = Wrapper(text_encoder)
|
2252 |
+
|
2253 |
+
# read prompts
|
2254 |
+
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
|
2255 |
+
prompts = f.readlines()
|
2256 |
+
|
2257 |
+
# prepare the scheduler
|
2258 |
+
sched_init_args = {}
|
2259 |
+
if args.sample_sampler == "ddim":
|
2260 |
+
scheduler_cls = DDIMScheduler
|
2261 |
+
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
2262 |
+
scheduler_cls = DDPMScheduler
|
2263 |
+
elif args.sample_sampler == "pndm":
|
2264 |
+
scheduler_cls = PNDMScheduler
|
2265 |
+
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
|
2266 |
+
scheduler_cls = LMSDiscreteScheduler
|
2267 |
+
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
|
2268 |
+
scheduler_cls = EulerDiscreteScheduler
|
2269 |
+
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
|
2270 |
+
scheduler_cls = EulerAncestralDiscreteScheduler
|
2271 |
+
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
2272 |
+
scheduler_cls = DPMSolverMultistepScheduler
|
2273 |
+
sched_init_args['algorithm_type'] = args.sample_sampler
|
2274 |
+
elif args.sample_sampler == "dpmsingle":
|
2275 |
+
scheduler_cls = DPMSolverSinglestepScheduler
|
2276 |
+
elif args.sample_sampler == "heun":
|
2277 |
+
scheduler_cls = HeunDiscreteScheduler
|
2278 |
+
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
|
2279 |
+
scheduler_cls = KDPM2DiscreteScheduler
|
2280 |
+
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
|
2281 |
+
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
2282 |
+
else:
|
2283 |
+
scheduler_cls = DDIMScheduler
|
2284 |
+
|
2285 |
+
if args.v_parameterization:
|
2286 |
+
sched_init_args['prediction_type'] = 'v_prediction'
|
2287 |
+
|
2288 |
+
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
|
2289 |
+
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
|
2290 |
+
beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
|
2291 |
+
|
2292 |
+
# force clip_sample to True
|
2293 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
2294 |
+
# print("set clip_sample to True")
|
2295 |
+
scheduler.config.clip_sample = True
|
2296 |
+
|
2297 |
+
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
2298 |
+
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
2299 |
+
pipeline.to(device)
|
2300 |
+
|
2301 |
+
save_dir = args.output_dir + "/sample"
|
2302 |
+
os.makedirs(save_dir, exist_ok=True)
|
2303 |
+
|
2304 |
+
rng_state = torch.get_rng_state()
|
2305 |
+
cuda_rng_state = torch.cuda.get_rng_state()
|
2306 |
+
|
2307 |
+
with torch.no_grad():
|
2308 |
+
with accelerator.autocast():
|
2309 |
+
for i, prompt in enumerate(prompts):
|
2310 |
+
if not accelerator.is_main_process:
|
2311 |
+
continue
|
2312 |
+
prompt = prompt.strip()
|
2313 |
+
if len(prompt) == 0 or prompt[0] == '#':
|
2314 |
+
continue
|
2315 |
+
|
2316 |
+
# subset of gen_img_diffusers
|
2317 |
+
prompt_args = prompt.split(' --')
|
2318 |
+
prompt = prompt_args[0]
|
2319 |
+
negative_prompt = None
|
2320 |
+
sample_steps = 30
|
2321 |
+
width = height = 512
|
2322 |
+
scale = 7.5
|
2323 |
+
seed = None
|
2324 |
+
for parg in prompt_args:
|
2325 |
+
try:
|
2326 |
+
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
|
2327 |
+
if m:
|
2328 |
+
width = int(m.group(1))
|
2329 |
+
continue
|
2330 |
+
|
2331 |
+
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
|
2332 |
+
if m:
|
2333 |
+
height = int(m.group(1))
|
2334 |
+
continue
|
2335 |
+
|
2336 |
+
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
|
2337 |
+
if m:
|
2338 |
+
seed = int(m.group(1))
|
2339 |
+
continue
|
2340 |
+
|
2341 |
+
m = re.match(r's (\d+)', parg, re.IGNORECASE)
|
2342 |
+
if m: # steps
|
2343 |
+
sample_steps = max(1, min(1000, int(m.group(1))))
|
2344 |
+
continue
|
2345 |
+
|
2346 |
+
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
|
2347 |
+
if m: # scale
|
2348 |
+
scale = float(m.group(1))
|
2349 |
+
continue
|
2350 |
+
|
2351 |
+
m = re.match(r'n (.+)', parg, re.IGNORECASE)
|
2352 |
+
if m: # negative prompt
|
2353 |
+
negative_prompt = m.group(1)
|
2354 |
+
continue
|
2355 |
+
|
2356 |
+
except ValueError as ex:
|
2357 |
+
print(f"Exception in parsing / 解析エラー: {parg}")
|
2358 |
+
print(ex)
|
2359 |
+
|
2360 |
+
if seed is not None:
|
2361 |
+
torch.manual_seed(seed)
|
2362 |
+
torch.cuda.manual_seed(seed)
|
2363 |
+
|
2364 |
+
if prompt_replacement is not None:
|
2365 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2366 |
+
if negative_prompt is not None:
|
2367 |
+
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2368 |
+
|
2369 |
+
height = max(64, height - height % 8) # round to divisible by 8
|
2370 |
+
width = max(64, width - width % 8) # round to divisible by 8
|
2371 |
+
print(f"prompt: {prompt}")
|
2372 |
+
print(f"negative_prompt: {negative_prompt}")
|
2373 |
+
print(f"height: {height}")
|
2374 |
+
print(f"width: {width}")
|
2375 |
+
print(f"sample_steps: {sample_steps}")
|
2376 |
+
print(f"scale: {scale}")
|
2377 |
+
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
2378 |
+
|
2379 |
+
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
2380 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
2381 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
2382 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
2383 |
+
|
2384 |
+
image.save(os.path.join(save_dir, img_filename))
|
2385 |
+
|
2386 |
+
# clear pipeline and cache to reduce vram usage
|
2387 |
+
del pipeline
|
2388 |
+
torch.cuda.empty_cache()
|
2389 |
+
|
2390 |
+
torch.set_rng_state(rng_state)
|
2391 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
2392 |
+
vae.to(org_vae_device)
|
2393 |
+
|
2394 |
# endregion
|
2395 |
|
2396 |
# region 前処理用
|
networks/check_lora_weights.py
CHANGED
@@ -21,7 +21,7 @@ def main(file):
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
-
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
+
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
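With this change each output row also carries the tensor shape (commas replaced by '-'); a hypothetical line:
lora_unet_mid_block_attentions_0_proj_in.lora_up.weight,(320-4),0.0123,0.0001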
networks/extract_lora_from_models.py
CHANGED
@@ -45,8 +45,13 @@ def svd(args):
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
assert len(lora_network_o.text_encoder_loras) == len(
|
51 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
52 |
|
@@ -85,13 +90,28 @@ def svd(args):
|
|
85 |
|
86 |
# make LoRA with svd
|
87 |
print("calculating by svd")
|
88 |
-
rank = args.dim
|
89 |
lora_weights = {}
|
90 |
with torch.no_grad():
|
91 |
for lora_name, mat in tqdm(list(diffs.items())):
|
|
|
92 |
conv2d = (len(mat.size()) == 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
if conv2d:
|
94 |
-
|
|
|
|
|
|
|
95 |
|
96 |
U, S, Vh = torch.linalg.svd(mat)
|
97 |
|
@@ -108,30 +128,27 @@ def svd(args):
|
|
108 |
U = U.clamp(low_val, hi_val)
|
109 |
Vh = Vh.clamp(low_val, hi_val)
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
115 |
-
lora_sd = lora_network_o.state_dict()
|
116 |
-
print(f"LoRA has {len(lora_sd)} weights.")
|
117 |
-
|
118 |
-
for key in list(lora_sd.keys()):
|
119 |
-
if "alpha" in key:
|
120 |
-
continue
|
121 |
|
122 |
-
|
123 |
-
|
124 |
|
125 |
-
|
126 |
-
# print(key, i, weights.size(), lora_sd[key].size())
|
127 |
-
if len(lora_sd[key].size()) == 4:
|
128 |
-
weights = weights.unsqueeze(2).unsqueeze(3)
|
129 |
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# load state dict to LoRA and save it
|
134 |
-
|
|
|
|
|
|
|
135 |
print(f"Loading extracted LoRA weights: {info}")
|
136 |
|
137 |
dir_name = os.path.dirname(args.save_to)
|
@@ -139,9 +156,9 @@ def svd(args):
|
|
139 |
os.makedirs(dir_name, exist_ok=True)
|
140 |
|
141 |
# minimum metadata
|
142 |
-
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
143 |
|
144 |
-
|
145 |
print(f"LoRA weights are saved to: {args.save_to}")
|
146 |
|
147 |
|
@@ -158,6 +175,8 @@ if __name__ == '__main__':
|
|
158 |
parser.add_argument("--save_to", type=str, default=None,
|
159 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
160 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
|
|
|
|
161 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
|
162 |
|
163 |
args = parser.parse_args()
|
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
+
if args.conv_dim is None:
|
49 |
+
kwargs = {}
|
50 |
+
else:
|
51 |
+
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
52 |
+
|
53 |
+
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
54 |
+
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
55 |
assert len(lora_network_o.text_encoder_loras) == len(
|
56 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
57 |
|
|
|
90 |
|
91 |
# make LoRA with svd
|
92 |
print("calculating by svd")
|
|
|
93 |
lora_weights = {}
|
94 |
with torch.no_grad():
|
95 |
for lora_name, mat in tqdm(list(diffs.items())):
|
96 |
+
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
97 |
conv2d = (len(mat.size()) == 4)
|
98 |
+
kernel_size = None if not conv2d else mat.size()[2:4]
|
99 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
100 |
+
|
101 |
+
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
102 |
+
out_dim, in_dim = mat.size()[0:2]
|
103 |
+
|
104 |
+
if args.device:
|
105 |
+
mat = mat.to(args.device)
|
106 |
+
|
107 |
+
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
108 |
+
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
109 |
+
|
110 |
if conv2d:
|
111 |
+
if conv2d_3x3:
|
112 |
+
mat = mat.flatten(start_dim=1)
|
113 |
+
else:
|
114 |
+
mat = mat.squeeze()
|
115 |
|
116 |
U, S, Vh = torch.linalg.svd(mat)
|
117 |
|
|
|
128 |
U = U.clamp(low_val, hi_val)
|
129 |
Vh = Vh.clamp(low_val, hi_val)
|
130 |
|
131 |
+
if conv2d:
|
132 |
+
U = U.reshape(out_dim, rank, 1, 1)
|
133 |
+
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
U = U.to("cpu").contiguous()
|
136 |
+
Vh = Vh.to("cpu").contiguous()
|
137 |
|
138 |
+
lora_weights[lora_name] = (U, Vh)
|
|
|
|
|
|
|
139 |
|
140 |
+
# make state dict for LoRA
|
141 |
+
lora_sd = {}
|
142 |
+
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
143 |
+
lora_sd[lora_name + '.lora_up.weight'] = up_weight
|
144 |
+
lora_sd[lora_name + '.lora_down.weight'] = down_weight
|
145 |
+
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
146 |
|
147 |
# load state dict to LoRA and save it
|
148 |
+
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
149 |
+
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
150 |
+
|
151 |
+
info = lora_network_save.load_state_dict(lora_sd)
|
152 |
print(f"Loading extracted LoRA weights: {info}")
|
153 |
|
154 |
dir_name = os.path.dirname(args.save_to)
|
|
|
156 |
os.makedirs(dir_name, exist_ok=True)
|
157 |
|
158 |
# minimum metadata
|
159 |
+
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
160 |
|
161 |
+
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
162 |
print(f"LoRA weights are saved to: {args.save_to}")
|
163 |
|
164 |
|
|
|
175 |
parser.add_argument("--save_to", type=str, default=None,
|
176 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
177 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
178 |
+
parser.add_argument("--conv_dim", type=int, default=None,
|
179 |
+
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
180 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
|
181 |
|
182 |
args = parser.parse_args()
|
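The elided middle of the SVD step above truncates to the top `rank` singular values before clamping; a generic low-rank sketch of that idea (illustrative only, not the exact elided code):
U, S, Vh = torch.linalg.svd(mat)   # mat: (out_dim, in_dim); conv 3x3 weights are flattened first
U = U[:, :rank] * S[:rank]         # keep the leading components and fold S into U (-> lora_up)
Vh = Vh[:rank, :]                  # -> lora_down
approx = U @ Vh                    # rank-`rank` approximation of the weight difference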
networks/lora.py
CHANGED
@@ -6,6 +6,7 @@
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
|
|
9 |
import torch
|
10 |
|
11 |
from library import train_util
|
@@ -20,22 +21,34 @@ class LoRAModule(torch.nn.Module):
|
|
20 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
21 |
super().__init__()
|
22 |
self.lora_name = lora_name
|
23 |
-
self.lora_dim = lora_dim
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
28 |
-
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
29 |
-
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
30 |
else:
|
31 |
in_dim = org_module.in_features
|
32 |
out_dim = org_module.out_features
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
if type(alpha) == torch.Tensor:
|
37 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
38 |
-
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
39 |
self.scale = alpha / self.lora_dim
|
40 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
41 |
|
@@ -45,69 +58,192 @@ class LoRAModule(torch.nn.Module):
|
|
45 |
|
46 |
self.multiplier = multiplier
|
47 |
self.org_module = org_module # remove in applying
|
|
|
|
|
48 |
|
49 |
def apply_to(self):
|
50 |
self.org_forward = self.org_module.forward
|
51 |
self.org_module.forward = self.forward
|
52 |
del self.org_module
|
53 |
|
|
|
|
|
|
|
|
|
54 |
def forward(self, x):
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
59 |
if network_dim is None:
|
60 |
network_dim = 4 # default
|
61 |
-
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
62 |
-
return network
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
for key, value in weights_sd.items():
|
76 |
-
if network_alpha is None and 'alpha' in key:
|
77 |
-
network_alpha = value
|
78 |
-
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
79 |
-
network_dim = value.size()[0]
|
80 |
|
81 |
-
if network_alpha is None:
|
82 |
-
network_alpha = network_dim
|
83 |
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
network.weights_sd = weights_sd
|
86 |
return network
|
87 |
|
88 |
|
89 |
class LoRANetwork(torch.nn.Module):
|
|
|
90 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
|
|
91 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
|
95 |
-
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
96 |
super().__init__()
|
97 |
self.multiplier = multiplier
|
|
|
98 |
self.lora_dim = lora_dim
|
99 |
self.alpha = alpha
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# create module instances
|
102 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
103 |
loras = []
|
104 |
for name, module in root_module.named_modules():
|
105 |
if module.__class__.__name__ in target_replace_modules:
|
|
|
106 |
for child_name, child_module in module.named_modules():
|
107 |
-
|
|
|
|
|
|
|
108 |
lora_name = prefix + '.' + name + '.' + child_name
|
109 |
lora_name = lora_name.replace('.', '_')
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
loras.append(lora)
|
112 |
return loras
|
113 |
|
@@ -115,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
|
|
115 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
116 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
119 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
120 |
|
121 |
self.weights_sd = None
|
@@ -126,6 +267,11 @@ class LoRANetwork(torch.nn.Module):
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
129 |
def load_weights(self, file):
|
130 |
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
from safetensors.torch import load_file, safe_open
|
@@ -235,3 +381,18 @@ class LoRANetwork(torch.nn.Module):
|
|
235 |
save_file(state_dict, file, metadata)
|
236 |
else:
|
237 |
torch.save(state_dict, file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
9 |
+
import numpy as np
|
10 |
import torch
|
11 |
|
12 |
from library import train_util
|
|
|
21 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
22 |
super().__init__()
|
23 |
self.lora_name = lora_name
|
|
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
|
|
|
|
28 |
else:
|
29 |
in_dim = org_module.in_features
|
30 |
out_dim = org_module.out_features
|
31 |
+
|
32 |
+
# if limit_rank:
|
33 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
34 |
+
# if self.lora_dim != lora_dim:
|
35 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
36 |
+
# else:
|
37 |
+
self.lora_dim = lora_dim
|
38 |
+
|
39 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
40 |
+
kernel_size = org_module.kernel_size
|
41 |
+
stride = org_module.stride
|
42 |
+
padding = org_module.padding
|
43 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
44 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
45 |
+
else:
|
46 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
47 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
48 |
|
49 |
if type(alpha) == torch.Tensor:
|
50 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
51 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
52 |
self.scale = alpha / self.lora_dim
|
53 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
54 |
|
|
|
58 |
|
59 |
self.multiplier = multiplier
|
60 |
self.org_module = org_module # remove in applying
|
61 |
+
self.region = None
|
62 |
+
self.region_mask = None
|
63 |
|
64 |
def apply_to(self):
|
65 |
self.org_forward = self.org_module.forward
|
66 |
self.org_module.forward = self.forward
|
67 |
del self.org_module
|
68 |
|
69 |
+
def set_region(self, region):
|
70 |
+
self.region = region
|
71 |
+
self.region_mask = None
|
72 |
+
|
73 |
def forward(self, x):
|
74 |
+
if self.region is None:
|
75 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
76 |
+
|
77 |
+
# regional LoRA FIXME same as additional-network extension
|
78 |
+
if x.size()[1] % 77 == 0:
|
79 |
+
# print(f"LoRA for context: {self.lora_name}")
|
80 |
+
self.region = None
|
81 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
82 |
+
|
83 |
+
# calculate region mask first time
|
84 |
+
if self.region_mask is None:
|
85 |
+
if len(x.size()) == 4:
|
86 |
+
h, w = x.size()[2:4]
|
87 |
+
else:
|
88 |
+
seq_len = x.size()[1]
|
89 |
+
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
90 |
+
h = int(self.region.size()[0] / ratio + .5)
|
91 |
+
w = seq_len // h
|
92 |
+
|
93 |
+
r = self.region.to(x.device)
|
94 |
+
if r.dtype == torch.bfloat16:
|
95 |
+
r = r.to(torch.float)
|
96 |
+
r = r.unsqueeze(0).unsqueeze(1)
|
97 |
+
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
98 |
+
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
|
99 |
+
r = r.to(x.dtype)
|
100 |
+
|
101 |
+
if len(x.size()) == 3:
|
102 |
+
r = torch.reshape(r, (1, x.size()[1], -1))
|
103 |
+
|
104 |
+
self.region_mask = r
|
105 |
+
|
106 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
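Stripped of the regional-mask branch, the injected forward is just the original output plus a scaled low-rank correction; a minimal self-contained sketch for a Linear layer (zero-init of the up projection is the usual LoRA convention, assumed here):
import torch

class TinyLoRALinear(torch.nn.Module):
    # illustrative stand-in for LoRAModule wrapping a Linear org_module
    def __init__(self, org: torch.nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.org = org
        self.lora_down = torch.nn.Linear(org.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, org.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)   # starts as a no-op
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        return self.org(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale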
107 |
|
108 |
|
109 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
110 |
if network_dim is None:
|
111 |
network_dim = 4 # default
|
|
|
|
|
112 |
|
113 |
+
# extract dim/alpha for conv2d, and block dim
|
114 |
+
conv_dim = kwargs.get('conv_dim', None)
|
115 |
+
conv_alpha = kwargs.get('conv_alpha', None)
|
116 |
+
if conv_dim is not None:
|
117 |
+
conv_dim = int(conv_dim)
|
118 |
+
if conv_alpha is None:
|
119 |
+
conv_alpha = 1.0
|
120 |
+
else:
|
121 |
+
conv_alpha = float(conv_alpha)
|
122 |
|
123 |
+
"""
|
124 |
+
block_dims = kwargs.get("block_dims")
|
125 |
+
block_alphas = None
|
126 |
+
|
127 |
+
if block_dims is not None:
|
128 |
+
block_dims = [int(d) for d in block_dims.split(',')]
|
129 |
+
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
130 |
+
block_alphas = kwargs.get("block_alphas")
|
131 |
+
if block_alphas is None:
|
132 |
+
block_alphas = [1] * len(block_dims)
|
133 |
+
else:
|
134 |
+
block_alphas = [int(a) for a in block_alphas(',')]
|
135 |
+
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
136 |
+
|
137 |
+
conv_block_dims = kwargs.get("conv_block_dims")
|
138 |
+
conv_block_alphas = None
|
139 |
+
|
140 |
+
if conv_block_dims is not None:
|
141 |
+
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
142 |
+
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
143 |
+
conv_block_alphas = kwargs.get("conv_block_alphas")
|
144 |
+
if conv_block_alphas is None:
|
145 |
+
conv_block_alphas = [1] * len(conv_block_dims)
|
146 |
+
else:
|
147 |
+
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
148 |
+
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
149 |
+
"""
|
150 |
|
151 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
|
152 |
+
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
|
153 |
+
return network
|
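The conv_dim/conv_alpha kwargs normally arrive as strings (for instance forwarded from a trainer's --network_args option, an assumption here), which is why they are cast above; an equivalent direct call:
# enables LoRA on 3x3 convs with rank 8 / alpha 4 in addition to the base rank 16 / alpha 8
network = create_network(1.0, 16, 8, vae, text_encoder, unet, conv_dim="8", conv_alpha="4")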
|
|
|
|
|
|
|
|
|
|
154 |
|
|
|
|
|
155 |
|
156 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
157 |
+
if weights_sd is None:
|
158 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
159 |
+
from safetensors.torch import load_file, safe_open
|
160 |
+
weights_sd = load_file(file)
|
161 |
+
else:
|
162 |
+
weights_sd = torch.load(file, map_location='cpu')
|
163 |
+
|
164 |
+
# get dim/alpha mapping
|
165 |
+
modules_dim = {}
|
166 |
+
modules_alpha = {}
|
167 |
+
for key, value in weights_sd.items():
|
168 |
+
if '.' not in key:
|
169 |
+
continue
|
170 |
+
|
171 |
+
lora_name = key.split('.')[0]
|
172 |
+
if 'alpha' in key:
|
173 |
+
modules_alpha[lora_name] = value
|
174 |
+
elif 'lora_down' in key:
|
175 |
+
dim = value.size()[0]
|
176 |
+
modules_dim[lora_name] = dim
|
177 |
+
# print(lora_name, value.size(), dim)
|
178 |
+
|
179 |
+
# support old LoRA without alpha
|
180 |
+
for key in modules_dim.keys():
|
181 |
+
if key not in modules_alpha:
|
182 |
+
modules_alpha[key] = modules_dim[key]  # fall back to dim as alpha for old LoRA files
|
183 |
+
|
184 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
185 |
network.weights_sd = weights_sd
|
186 |
return network
|
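A sketch of rebuilding a network straight from a saved LoRA file (per-module dims and alphas are recovered from the state dict, so no explicit dim is needed; the file name is hypothetical):
network = create_network_from_weights(0.8, "my_lora.safetensors", vae, text_encoder, unet)
network.apply_to(text_encoder, unet)               # swap the wrapped forwards in
info = network.load_state_dict(network.weights_sd)
print(info)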
187 |
|
188 |
|
189 |
class LoRANetwork(torch.nn.Module):
|
190 |
+
# is it possible to apply conv_in and conv_out?
|
191 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
192 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
193 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
194 |
LORA_PREFIX_UNET = 'lora_unet'
|
195 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
196 |
|
197 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
|
198 |
super().__init__()
|
199 |
self.multiplier = multiplier
|
200 |
+
|
201 |
self.lora_dim = lora_dim
|
202 |
self.alpha = alpha
|
203 |
+
self.conv_lora_dim = conv_lora_dim
|
204 |
+
self.conv_alpha = conv_alpha
|
205 |
+
|
206 |
+
if modules_dim is not None:
|
207 |
+
print(f"create LoRA network from weights")
|
208 |
+
else:
|
209 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
210 |
+
|
211 |
+
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
212 |
+
if self.apply_to_conv2d_3x3:
|
213 |
+
if self.conv_alpha is None:
|
214 |
+
self.conv_alpha = self.alpha
|
215 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
216 |
|
217 |
# create module instances
|
218 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
219 |
loras = []
|
220 |
for name, module in root_module.named_modules():
|
221 |
if module.__class__.__name__ in target_replace_modules:
|
222 |
+
# TODO get block index here
|
223 |
for child_name, child_module in module.named_modules():
|
224 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
225 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
226 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
227 |
+
if is_linear or is_conv2d:
|
228 |
lora_name = prefix + '.' + name + '.' + child_name
|
229 |
lora_name = lora_name.replace('.', '_')
|
230 |
+
|
231 |
+
if modules_dim is not None:
|
232 |
+
if lora_name not in modules_dim:
|
233 |
+
continue # no LoRA module in this weights file
|
234 |
+
dim = modules_dim[lora_name]
|
235 |
+
alpha = modules_alpha[lora_name]
|
236 |
+
else:
|
237 |
+
if is_linear or is_conv2d_1x1:
|
238 |
+
dim = self.lora_dim
|
239 |
+
alpha = self.alpha
|
240 |
+
elif self.apply_to_conv2d_3x3:
|
241 |
+
dim = self.conv_lora_dim
|
242 |
+
alpha = self.conv_alpha
|
243 |
+
else:
|
244 |
+
continue
|
245 |
+
|
246 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
247 |
loras.append(lora)
|
248 |
return loras
|
249 |
|
|
|
251 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
252 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
253 |
|
254 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
255 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
256 |
+
if modules_dim is not None or self.conv_lora_dim is not None:
|
257 |
+
target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3  # avoid mutating the shared class-level list in place
|
258 |
+
|
259 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
260 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
261 |
|
262 |
self.weights_sd = None
|
|
|
267 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
268 |
names.add(lora.lora_name)
|
269 |
|
270 |
+
def set_multiplier(self, multiplier):
|
271 |
+
self.multiplier = multiplier
|
272 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
273 |
+
lora.multiplier = self.multiplier
|
274 |
+
|
275 |
def load_weights(self, file):
|
276 |
if os.path.splitext(file)[1] == '.safetensors':
|
277 |
from safetensors.torch import load_file, safe_open
|
|
|
381 |
save_file(state_dict, file, metadata)
|
382 |
else:
|
383 |
torch.save(state_dict, file)
|
384 |
+
|
385 |
+
@staticmethod
|
386 |
+
def set_regions(networks, image):
|
387 |
+
image = image.astype(np.float32) / 255.0
|
388 |
+
for i, network in enumerate(networks[:3]):
|
389 |
+
# NOTE: consider averaging the overlapping area
|
390 |
+
region = image[:, :, i]
|
391 |
+
if region.max() == 0:
|
392 |
+
continue
|
393 |
+
region = torch.tensor(region)
|
394 |
+
network.set_region(region)
|
395 |
+
|
396 |
+
def set_region(self, region):
|
397 |
+
for lora in self.unet_loras:
|
398 |
+
lora.set_region(region)
|
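set_regions assigns up to three networks to the R, G and B channels of a mask image; a hedged usage sketch with a synthetic mask (HxWx3 uint8):
import numpy as np

mask = np.zeros((512, 512, 3), dtype=np.uint8)
mask[:, :256, 0] = 255                                 # red channel: region for the first network
mask[:, 256:, 1] = 255                                 # green channel: region for the second network
LoRANetwork.set_regions([network1, network2], mask)    # network1/network2 are placeholders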
networks/merge_lora.py
CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
48        for name, module in root_module.named_modules():
49          if module.__class__.__name__ in target_replace_modules:
50            for child_name, child_module in module.named_modules():
51 -            if child_module.__class__.__name__ == "Linear" or
52              lora_name = prefix + '.' + name + '.' + child_name
53              lora_name = lora_name.replace('.', '_')
54              name_to_module[lora_name] = child_module
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
80
81        # W <- W + U * D
82        weight = module.weight
83        if len(weight.size()) == 2:
84          # linear
85          weight = weight + ratio * (up_weight @ down_weight) * scale
86 -
87 -        # conv2d
88        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
89                                   ).unsqueeze(2).unsqueeze(3) * scale
90
91        module.weight = torch.nn.Parameter(weight)
92
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
123       alphas[lora_module_name] = alpha
124       if lora_module_name not in base_alphas:
125         base_alphas[lora_module_name] = alpha
126 -
127       print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
128
129       # merge
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
145       merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
146       else:
147         merged_sd[key] = lora_sd[key] * scale
148 -
149       # set alpha to sd
150       for lora_module_name, alpha in base_alphas.items():
151         key = lora_module_name + ".alpha"

48        for name, module in root_module.named_modules():
49          if module.__class__.__name__ in target_replace_modules:
50            for child_name, child_module in module.named_modules():
51 +            if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
52              lora_name = prefix + '.' + name + '.' + child_name
53              lora_name = lora_name.replace('.', '_')
54              name_to_module[lora_name] = child_module

80
81        # W <- W + U * D
82        weight = module.weight
83 +      # print(module_name, down_weight.size(), up_weight.size())
84        if len(weight.size()) == 2:
85          # linear
86          weight = weight + ratio * (up_weight @ down_weight) * scale
87 +      elif down_weight.size()[2:4] == (1, 1):
88 +        # conv2d 1x1
89        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
90                                   ).unsqueeze(2).unsqueeze(3) * scale
91 +      else:
92 +        # conv2d 3x3
93 +        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
94 +        # print(conved.size(), weight.size(), module.stride, module.padding)
95 +        weight = weight + ratio * conved * scale
96
97        module.weight = torch.nn.Parameter(weight)
98

129       alphas[lora_module_name] = alpha
130       if lora_module_name not in base_alphas:
131         base_alphas[lora_module_name] = alpha
132 +
133       print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
134
135       # merge

151       merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
152       else:
153         merged_sd[key] = lora_sd[key] * scale
154 +
155       # set alpha to sd
156       for lora_module_name, alpha in base_alphas.items():
157         key = lora_module_name + ".alpha"
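For reference, the branches above add the same delta weight that the LoRA applies at inference, W <- W + ratio * (up composed with down) * scale. A condensed standalone sketch of that delta computation (not repository code, simply mirroring the three cases in the diff):

```python
import torch
import torch.nn.functional as F

def lora_delta(up_weight, down_weight, scale, ratio=1.0):
    if up_weight.dim() == 2:                              # Linear: (out, r) @ (r, in)
        return ratio * (up_weight @ down_weight) * scale
    if down_weight.shape[2:4] == (1, 1):                  # Conv2d 1x1: matmul on the squeezed dims
        delta = up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
        return ratio * delta.unsqueeze(2).unsqueeze(3) * scale
    # Conv2d 3x3: compose the two convolutions into a single (out, in, k, k) kernel
    conved = F.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
    return ratio * conved * scale
```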
networks/resize_lora.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1 |
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
-
# Thanks to cloneofsimo
|
4 |
|
5 |
import argparse
|
6 |
-
import os
|
7 |
import torch
|
8 |
from safetensors.torch import load_file, save_file, safe_open
|
9 |
from tqdm import tqdm
|
10 |
from library import train_util, model_util
|
|
|
11 |
|
|
|
12 |
|
13 |
def load_state_dict(file_name, dtype):
|
14 |
if model_util.is_safetensors(file_name):
|
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|
38 |
torch.save(model, file_name)
|
39 |
|
40 |
|
41 |
-
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
42 |
network_alpha = None
|
43 |
network_dim = None
|
44 |
verbose_str = "\n"
|
45 |
-
|
46 |
-
CLAMP_QUANTILE = 0.99
|
47 |
|
48 |
# Extract loaded lora dim and alpha
|
49 |
for key, value in lora_sd.items():
|
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|
57 |
network_alpha = network_dim
|
58 |
|
59 |
scale = network_alpha/network_dim
|
60 |
-
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
61 |
|
62 |
-
|
|
|
63 |
|
64 |
lora_down_weight = None
|
65 |
lora_up_weight = None
|
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|
68 |
block_down_name = None
|
69 |
block_up_name = None
|
70 |
|
71 |
-
print("resizing lora...")
|
72 |
with torch.no_grad():
|
73 |
for key, value in tqdm(lora_sd.items()):
|
74 |
if 'lora_down' in key:
|
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|
85 |
conv2d = (len(lora_down_weight.size()) == 4)
|
86 |
|
87 |
if conv2d:
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
lora_up_weight = lora_up_weight.to(args.device)
|
94 |
-
lora_down_weight = lora_down_weight.to(args.device)
|
95 |
-
|
96 |
-
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
97 |
-
|
98 |
-
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
99 |
|
100 |
if verbose:
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
U = U @ torch.diag(S)
|
109 |
|
110 |
-
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
U = U.clamp(low_val, hi_val)
|
117 |
-
Vh = Vh.clamp(low_val, hi_val)
|
118 |
-
|
119 |
-
if conv2d:
|
120 |
-
U = U.unsqueeze(2).unsqueeze(3)
|
121 |
-
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
122 |
-
|
123 |
-
if device:
|
124 |
-
U = U.to(org_device)
|
125 |
-
Vh = Vh.to(org_device)
|
126 |
-
|
127 |
-
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
128 |
-
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
129 |
-
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
130 |
|
131 |
block_down_name = None
|
132 |
block_up_name = None
|
133 |
lora_down_weight = None
|
134 |
lora_up_weight = None
|
135 |
weights_loaded = False
|
|
|
136 |
|
137 |
if verbose:
|
138 |
print(verbose_str)
|
|
|
|
|
139 |
print("resizing complete")
|
140 |
return o_lora_sd, network_dim, new_alpha
|
141 |
|
@@ -151,6 +274,9 @@ def resize(args):
|
|
151 |
return torch.bfloat16
|
152 |
return None
|
153 |
|
|
|
|
|
|
|
154 |
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
155 |
save_dtype = str_to_dtype(args.save_precision)
|
156 |
if save_dtype is None:
|
@@ -159,17 +285,23 @@ def resize(args):
|
|
159 |
print("loading Model...")
|
160 |
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
161 |
|
162 |
-
print("
|
163 |
-
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
164 |
|
165 |
# update metadata
|
166 |
if metadata is None:
|
167 |
metadata = {}
|
168 |
|
169 |
comment = metadata.get("ss_training_comment", "")
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
175 |
metadata["sshs_model_hash"] = model_hash
|
@@ -193,6 +325,11 @@ if __name__ == '__main__':
|
|
193 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
194 |
parser.add_argument("--verbose", action="store_true",
|
195 |
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
args = parser.parse_args()
|
198 |
resize(args)
|
|
|
1 |
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo
|
4 |
|
5 |
import argparse
|
|
|
6 |
import torch
|
7 |
from safetensors.torch import load_file, save_file, safe_open
|
8 |
from tqdm import tqdm
|
9 |
from library import train_util, model_util
|
10 |
+
import numpy as np
|
11 |
|
12 |
+
MIN_SV = 1e-6
|
13 |
|
14 |
def load_state_dict(file_name, dtype):
|
15 |
if model_util.is_safetensors(file_name):
|
|
|
39 |
torch.save(model, file_name)
|
40 |
|
41 |
|
42 |
+
def index_sv_cumulative(S, target):
|
43 |
+
original_sum = float(torch.sum(S))
|
44 |
+
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
45 |
+
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
46 |
+
if index >= len(S):
|
47 |
+
index = len(S) - 1
|
48 |
+
|
49 |
+
return index
|
50 |
+
|
51 |
+
|
52 |
+
def index_sv_fro(S, target):
|
53 |
+
S_squared = S.pow(2)
|
54 |
+
s_fro_sq = float(torch.sum(S_squared))
|
55 |
+
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
56 |
+
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
57 |
+
if index >= len(S):
|
58 |
+
index = len(S) - 1
|
59 |
+
|
60 |
+
return index
|
61 |
+
|
62 |
+
|
63 |
+
# Modified from Kohaku-blueleaf's extract/merge functions
|
64 |
+
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
65 |
+
out_size, in_size, kernel_size, _ = weight.size()
|
66 |
+
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
67 |
+
|
68 |
+
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
69 |
+
lora_rank = param_dict["new_rank"]
|
70 |
+
|
71 |
+
U = U[:, :lora_rank]
|
72 |
+
S = S[:lora_rank]
|
73 |
+
U = U @ torch.diag(S)
|
74 |
+
Vh = Vh[:lora_rank, :]
|
75 |
+
|
76 |
+
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
77 |
+
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
78 |
+
del U, S, Vh, weight
|
79 |
+
return param_dict
|
80 |
+
|
81 |
+
|
82 |
+
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
83 |
+
out_size, in_size = weight.size()
|
84 |
+
|
85 |
+
U, S, Vh = torch.linalg.svd(weight.to(device))
|
86 |
+
|
87 |
+
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
88 |
+
lora_rank = param_dict["new_rank"]
|
89 |
+
|
90 |
+
U = U[:, :lora_rank]
|
91 |
+
S = S[:lora_rank]
|
92 |
+
U = U @ torch.diag(S)
|
93 |
+
Vh = Vh[:lora_rank, :]
|
94 |
+
|
95 |
+
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
96 |
+
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
97 |
+
del U, S, Vh, weight
|
98 |
+
return param_dict
|
99 |
+
|
100 |
+
|
101 |
+
def merge_conv(lora_down, lora_up, device):
|
102 |
+
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
103 |
+
out_size, out_rank, _, _ = lora_up.shape
|
104 |
+
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
105 |
+
|
106 |
+
lora_down = lora_down.to(device)
|
107 |
+
lora_up = lora_up.to(device)
|
108 |
+
|
109 |
+
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
110 |
+
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
111 |
+
del lora_up, lora_down
|
112 |
+
return weight
|
113 |
+
|
114 |
+
|
115 |
+
def merge_linear(lora_down, lora_up, device):
|
116 |
+
in_rank, in_size = lora_down.shape
|
117 |
+
out_size, out_rank = lora_up.shape
|
118 |
+
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
119 |
+
|
120 |
+
lora_down = lora_down.to(device)
|
121 |
+
lora_up = lora_up.to(device)
|
122 |
+
|
123 |
+
weight = lora_up @ lora_down
|
124 |
+
del lora_up, lora_down
|
125 |
+
return weight
|
126 |
+
|
127 |
+
|
128 |
+
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
129 |
+
param_dict = {}
|
130 |
+
|
131 |
+
if dynamic_method=="sv_ratio":
|
132 |
+
# Calculate new dim and alpha based off ratio
|
133 |
+
max_sv = S[0]
|
134 |
+
min_sv = max_sv/dynamic_param
|
135 |
+
new_rank = max(torch.sum(S > min_sv).item(),1)
|
136 |
+
new_alpha = float(scale*new_rank)
|
137 |
+
|
138 |
+
elif dynamic_method=="sv_cumulative":
|
139 |
+
# Calculate new dim and alpha based off cumulative sum
|
140 |
+
new_rank = index_sv_cumulative(S, dynamic_param)
|
141 |
+
new_rank = max(new_rank, 1)
|
142 |
+
new_alpha = float(scale*new_rank)
|
143 |
+
|
144 |
+
elif dynamic_method=="sv_fro":
|
145 |
+
# Calculate new dim and alpha based off sqrt sum of squares
|
146 |
+
new_rank = index_sv_fro(S, dynamic_param)
|
147 |
+
new_rank = min(max(new_rank, 1), len(S)-1)
|
148 |
+
new_alpha = float(scale*new_rank)
|
149 |
+
else:
|
150 |
+
new_rank = rank
|
151 |
+
new_alpha = float(scale*new_rank)
|
152 |
+
|
153 |
+
|
154 |
+
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
155 |
+
new_rank = 1
|
156 |
+
new_alpha = float(scale*new_rank)
|
157 |
+
elif new_rank > rank: # cap max rank at rank
|
158 |
+
new_rank = rank
|
159 |
+
new_alpha = float(scale*new_rank)
|
160 |
+
|
161 |
+
|
162 |
+
# Calculate resize info
|
163 |
+
s_sum = torch.sum(torch.abs(S))
|
164 |
+
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
165 |
+
|
166 |
+
S_squared = S.pow(2)
|
167 |
+
s_fro = torch.sqrt(torch.sum(S_squared))
|
168 |
+
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
169 |
+
fro_percent = float(s_red_fro/s_fro)
|
170 |
+
|
171 |
+
param_dict["new_rank"] = new_rank
|
172 |
+
param_dict["new_alpha"] = new_alpha
|
173 |
+
param_dict["sum_retained"] = (s_rank)/s_sum
|
174 |
+
param_dict["fro_retained"] = fro_percent
|
175 |
+
param_dict["max_ratio"] = S[0]/S[new_rank]
|
176 |
+
|
177 |
+
return param_dict
|
178 |
+
|
179 |
+
|
180 |
+
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
181 |
network_alpha = None
|
182 |
network_dim = None
|
183 |
verbose_str = "\n"
|
184 |
+
fro_list = []
|
|
|
185 |
|
186 |
# Extract loaded lora dim and alpha
|
187 |
for key, value in lora_sd.items():
|
|
|
195 |
network_alpha = network_dim
|
196 |
|
197 |
scale = network_alpha/network_dim
|
|
|
198 |
|
199 |
+
if dynamic_method:
|
200 |
+
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
201 |
|
202 |
lora_down_weight = None
|
203 |
lora_up_weight = None
|
|
|
206 |
block_down_name = None
|
207 |
block_up_name = None
|
208 |
|
|
|
209 |
with torch.no_grad():
|
210 |
for key, value in tqdm(lora_sd.items()):
|
211 |
if 'lora_down' in key:
|
|
|
222 |
conv2d = (len(lora_down_weight.size()) == 4)
|
223 |
|
224 |
if conv2d:
|
225 |
+
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
226 |
+
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
227 |
+
else:
|
228 |
+
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
229 |
+
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
if verbose:
|
232 |
+
max_ratio = param_dict['max_ratio']
|
233 |
+
sum_retained = param_dict['sum_retained']
|
234 |
+
fro_retained = param_dict['fro_retained']
|
235 |
+
if not np.isnan(fro_retained):
|
236 |
+
fro_list.append(float(fro_retained))
|
237 |
|
238 |
+
verbose_str+=f"{block_down_name:75} | "
|
239 |
+
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
|
|
240 |
|
241 |
+
if verbose and dynamic_method:
|
242 |
+
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
243 |
+
else:
|
244 |
+
verbose_str+=f"\n"
|
245 |
|
246 |
+
new_alpha = param_dict['new_alpha']
|
247 |
+
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
248 |
+
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
249 |
+
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
250 |
|
251 |
block_down_name = None
|
252 |
block_up_name = None
|
253 |
lora_down_weight = None
|
254 |
lora_up_weight = None
|
255 |
weights_loaded = False
|
256 |
+
del param_dict
|
257 |
|
258 |
if verbose:
|
259 |
print(verbose_str)
|
260 |
+
|
261 |
+
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
262 |
print("resizing complete")
|
263 |
return o_lora_sd, network_dim, new_alpha
|
264 |
|
|
|
274 |
return torch.bfloat16
|
275 |
return None
|
276 |
|
277 |
+
if args.dynamic_method and not args.dynamic_param:
|
278 |
+
raise Exception("If using dynamic_method, then dynamic_param is required")
|
279 |
+
|
280 |
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
281 |
save_dtype = str_to_dtype(args.save_precision)
|
282 |
if save_dtype is None:
|
|
|
285 |
print("loading Model...")
|
286 |
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
287 |
|
288 |
+
print("Resizing Lora...")
|
289 |
+
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
290 |
|
291 |
# update metadata
|
292 |
if metadata is None:
|
293 |
metadata = {}
|
294 |
|
295 |
comment = metadata.get("ss_training_comment", "")
|
296 |
+
|
297 |
+
if not args.dynamic_method:
|
298 |
+
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
299 |
+
metadata["ss_network_dim"] = str(args.new_rank)
|
300 |
+
metadata["ss_network_alpha"] = str(new_alpha)
|
301 |
+
else:
|
302 |
+
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
303 |
+
metadata["ss_network_dim"] = 'Dynamic'
|
304 |
+
metadata["ss_network_alpha"] = 'Dynamic'
|
305 |
|
306 |
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
307 |
metadata["sshs_model_hash"] = model_hash
|
|
|
325 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
326 |
parser.add_argument("--verbose", action="store_true",
|
327 |
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
328 |
+
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
329 |
+
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
330 |
+
parser.add_argument("--dynamic_param", type=float, default=None,
|
331 |
+
help="Specify target for dynamic reduction")
|
332 |
+
|
333 |
|
334 |
args = parser.parse_args()
|
335 |
resize(args)
|
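Each of the new `--dynamic_method` choices reduces to picking a per-module rank from the singular values of the merged weight. A small standalone sketch of the `sv_fro` criterion (mirroring `index_sv_fro` above, with a toy spectrum for illustration):

```python
import torch

def index_sv_fro(S, target):
    # smallest rank whose leading singular values retain `target` of the Frobenius norm
    S_squared = S.pow(2)
    cum = torch.cumsum(S_squared, dim=0) / float(torch.sum(S_squared))
    return min(int(torch.searchsorted(cum, target ** 2)) + 1, len(S) - 1)

S = torch.tensor([10.0, 5.0, 2.0, 1.0, 0.5, 0.1])  # toy singular-value spectrum
rank = index_sv_fro(S, 0.95)
print(rank, float(torch.sqrt(S[:rank].pow(2).sum() / S.pow(2).sum())))  # 2, ~0.98
```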
networks/svd_merge_lora.py
CHANGED
@@ -23,19 +23,20 @@ def load_state_dict(file_name, dtype):
|
|
23 |
return sd
|
24 |
|
25 |
|
26 |
-
def save_to_file(file_name,
|
27 |
if dtype is not None:
|
28 |
for key in list(state_dict.keys()):
|
29 |
if type(state_dict[key]) == torch.Tensor:
|
30 |
state_dict[key] = state_dict[key].to(dtype)
|
31 |
|
32 |
if os.path.splitext(file_name)[1] == '.safetensors':
|
33 |
-
save_file(
|
34 |
else:
|
35 |
-
torch.save(
|
36 |
|
37 |
|
38 |
-
def merge_lora_models(models, ratios, new_rank, device,
|
|
|
39 |
merged_sd = {}
|
40 |
for model, ratio in zip(models, ratios):
|
41 |
print(f"loading: {model}")
|
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|
58 |
in_dim = down_weight.size()[1]
|
59 |
out_dim = up_weight.size()[0]
|
60 |
conv2d = len(down_weight.size()) == 4
|
61 |
-
|
|
|
62 |
|
63 |
# make original weight if not exist
|
64 |
if lora_module_name not in merged_sd:
|
65 |
-
weight = torch.zeros((out_dim, in_dim,
|
66 |
if device:
|
67 |
weight = weight.to(device)
|
68 |
else:
|
@@ -75,11 +77,18 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|
75 |
|
76 |
# W <- W + U * D
|
77 |
scale = (alpha / network_dim)
|
|
|
|
|
|
|
|
|
78 |
if not conv2d: # linear
|
79 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
80 |
-
|
81 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
82 |
).unsqueeze(2).unsqueeze(3) * scale
|
|
|
|
|
|
|
83 |
|
84 |
merged_sd[lora_module_name] = weight
|
85 |
|
@@ -89,16 +98,26 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|
89 |
with torch.no_grad():
|
90 |
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
91 |
conv2d = (len(mat.size()) == 4)
|
|
|
|
|
|
|
|
|
92 |
if conv2d:
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
U, S, Vh = torch.linalg.svd(mat)
|
96 |
|
97 |
-
U = U[:, :
|
98 |
-
S = S[:
|
99 |
U = U @ torch.diag(S)
|
100 |
|
101 |
-
Vh = Vh[:
|
102 |
|
103 |
dist = torch.cat([U.flatten(), Vh.flatten()])
|
104 |
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
@@ -107,16 +126,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|
107 |
U = U.clamp(low_val, hi_val)
|
108 |
Vh = Vh.clamp(low_val, hi_val)
|
109 |
|
|
|
|
|
|
|
|
|
110 |
up_weight = U
|
111 |
down_weight = Vh
|
112 |
|
113 |
-
if conv2d:
|
114 |
-
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
115 |
-
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
116 |
-
|
117 |
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
118 |
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
119 |
-
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(
|
120 |
|
121 |
return merged_lora_sd
|
122 |
|
@@ -138,10 +157,11 @@ def merge(args):
|
|
138 |
if save_dtype is None:
|
139 |
save_dtype = merge_dtype
|
140 |
|
141 |
-
|
|
|
142 |
|
143 |
print(f"saving model to: {args.save_to}")
|
144 |
-
save_to_file(args.save_to, state_dict,
|
145 |
|
146 |
|
147 |
if __name__ == '__main__':
|
@@ -158,6 +178,8 @@ if __name__ == '__main__':
|
|
158 |
help="ratios for each model / それぞれのLoRAモデルの比率")
|
159 |
parser.add_argument("--new_rank", type=int, default=4,
|
160 |
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
|
|
|
|
161 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
|
163 |
args = parser.parse_args()
|
|
|
23 |
return sd
|
24 |
|
25 |
|
26 |
+
def save_to_file(file_name, state_dict, dtype):
|
27 |
if dtype is not None:
|
28 |
for key in list(state_dict.keys()):
|
29 |
if type(state_dict[key]) == torch.Tensor:
|
30 |
state_dict[key] = state_dict[key].to(dtype)
|
31 |
|
32 |
if os.path.splitext(file_name)[1] == '.safetensors':
|
33 |
+
save_file(state_dict, file_name)
|
34 |
else:
|
35 |
+
torch.save(state_dict, file_name)
|
36 |
|
37 |
|
38 |
+
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
39 |
+
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
40 |
merged_sd = {}
|
41 |
for model, ratio in zip(models, ratios):
|
42 |
print(f"loading: {model}")
|
|
|
59 |
in_dim = down_weight.size()[1]
|
60 |
out_dim = up_weight.size()[0]
|
61 |
conv2d = len(down_weight.size()) == 4
|
62 |
+
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
63 |
+
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
64 |
|
65 |
# make original weight if not exist
|
66 |
if lora_module_name not in merged_sd:
|
67 |
+
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
68 |
if device:
|
69 |
weight = weight.to(device)
|
70 |
else:
|
|
|
77 |
|
78 |
# W <- W + U * D
|
79 |
scale = (alpha / network_dim)
|
80 |
+
|
81 |
+
if device: # and isinstance(scale, torch.Tensor):
|
82 |
+
scale = scale.to(device)
|
83 |
+
|
84 |
if not conv2d: # linear
|
85 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
+
elif kernel_size == (1, 1):
|
87 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
88 |
).unsqueeze(2).unsqueeze(3) * scale
|
89 |
+
else:
|
90 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
91 |
+
weight = weight + ratio * conved * scale
|
92 |
|
93 |
merged_sd[lora_module_name] = weight
|
94 |
|
|
|
98 |
with torch.no_grad():
|
99 |
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
100 |
conv2d = (len(mat.size()) == 4)
|
101 |
+
kernel_size = None if not conv2d else mat.size()[2:4]
|
102 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
103 |
+
out_dim, in_dim = mat.size()[0:2]
|
104 |
+
|
105 |
if conv2d:
|
106 |
+
if conv2d_3x3:
|
107 |
+
mat = mat.flatten(start_dim=1)
|
108 |
+
else:
|
109 |
+
mat = mat.squeeze()
|
110 |
+
|
111 |
+
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
112 |
+
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
113 |
|
114 |
U, S, Vh = torch.linalg.svd(mat)
|
115 |
|
116 |
+
U = U[:, :module_new_rank]
|
117 |
+
S = S[:module_new_rank]
|
118 |
U = U @ torch.diag(S)
|
119 |
|
120 |
+
Vh = Vh[:module_new_rank, :]
|
121 |
|
122 |
dist = torch.cat([U.flatten(), Vh.flatten()])
|
123 |
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
|
|
126 |
U = U.clamp(low_val, hi_val)
|
127 |
Vh = Vh.clamp(low_val, hi_val)
|
128 |
|
129 |
+
if conv2d:
|
130 |
+
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
131 |
+
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
132 |
+
|
133 |
up_weight = U
|
134 |
down_weight = Vh
|
135 |
|
|
|
|
|
|
|
|
|
136 |
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
137 |
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
138 |
+
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
|
139 |
|
140 |
return merged_lora_sd
|
141 |
|
|
|
157 |
if save_dtype is None:
|
158 |
save_dtype = merge_dtype
|
159 |
|
160 |
+
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
161 |
+
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
162 |
|
163 |
print(f"saving model to: {args.save_to}")
|
164 |
+
save_to_file(args.save_to, state_dict, save_dtype)
|
165 |
|
166 |
|
167 |
if __name__ == '__main__':
|
|
|
178 |
help="ratios for each model / それぞれのLoRAモデルの比率")
|
179 |
parser.add_argument("--new_rank", type=int, default=4,
|
180 |
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
181 |
+
parser.add_argument("--new_conv_rank", type=int, default=None,
|
182 |
+
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
183 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
184 |
|
185 |
args = parser.parse_args()
|
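The script merges the given LoRAs into full delta-weight matrices and then re-extracts each module with a truncated SVD at `--new_rank` (or `--new_conv_rank` for 3x3 convolutions). Below is a condensed sketch of that re-extraction step; the symmetric clamp around the 0.99 quantile is an assumption based on the `CLAMP_QUANTILE` constant used elsewhere in these scripts:

```python
import torch

CLAMP_QUANTILE = 0.99

def svd_reextract(mat, new_rank):
    conv2d = (mat.dim() == 4)
    if conv2d:
        out_dim, in_dim = mat.shape[0:2]
        kernel_size = mat.shape[2:4]
        mat2d = mat.flatten(start_dim=1) if kernel_size != (1, 1) else mat.squeeze()
    else:
        out_dim, in_dim = mat.shape
        mat2d = mat
    rank = min(new_rank, in_dim, out_dim)   # LoRA rank cannot exceed the original dims
    U, S, Vh = torch.linalg.svd(mat2d)
    U, Vh = U[:, :rank] @ torch.diag(S[:rank]), Vh[:rank, :]
    hi = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
    U, Vh = U.clamp(-hi, hi), Vh.clamp(-hi, hi)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return U, Vh   # lora_up.weight, lora_down.weight
```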
requirements.txt
CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
12    gradio==3.16.2
13    altair==4.2.2
14    easygui==0.98.3
15    # for BLIP captioning
16    requests==2.28.2
17    timm==0.6.12

12    gradio==3.16.2
13    altair==4.2.2
14    easygui==0.98.3
15 +  toml==0.10.2
16 +  voluptuous==0.13.1
17    # for BLIP captioning
18    requests==2.28.2
19    timm==0.6.12
tools/canny.py
ADDED
@@ -0,0 +1,24 @@
1  + import argparse
2  + import cv2
3  +
4  +
5  + def canny(args):
6  +   img = cv2.imread(args.input)
7  +   img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
8  +
9  +   canny_img = cv2.Canny(img, args.thres1, args.thres2)
10 +   # canny_img = 255 - canny_img
11 +
12 +   cv2.imwrite(args.output, canny_img)
13 +   print("done!")
14 +
15 +
16 + if __name__ == '__main__':
17 +   parser = argparse.ArgumentParser()
18 +   parser.add_argument("--input", type=str, default=None, help="input path")
19 +   parser.add_argument("--output", type=str, default=None, help="output path")
20 +   parser.add_argument("--thres1", type=int, default=32, help="thres1")
21 +   parser.add_argument("--thres2", type=int, default=224, help="thres2")
22 +
23 +   args = parser.parse_args()
24 +   canny(args)
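Typical use (file names are illustrative) is to pre-compute an edge map that can then be passed as a hint to the ControlNet helper added below:

```python
import subprocess

subprocess.run(
    ["python", "tools/canny.py",
     "--input", "hint_src.png",       # hypothetical source image
     "--output", "hint_canny.png",    # hypothetical output path
     "--thres1", "32", "--thres2", "224"],
    check=True,
)
```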
tools/original_control_net.py
ADDED
@@ -0,0 +1,320 @@
|
1 |
+
from typing import List, NamedTuple, Any
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from safetensors.torch import load_file
|
6 |
+
|
7 |
+
from diffusers import UNet2DConditionModel
|
8 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
9 |
+
|
10 |
+
import library.model_util as model_util
|
11 |
+
|
12 |
+
|
13 |
+
class ControlNetInfo(NamedTuple):
|
14 |
+
unet: Any
|
15 |
+
net: Any
|
16 |
+
prep: Any
|
17 |
+
weight: float
|
18 |
+
ratio: float
|
19 |
+
|
20 |
+
|
21 |
+
class ControlNet(torch.nn.Module):
|
22 |
+
def __init__(self) -> None:
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
# make control model
|
26 |
+
self.control_model = torch.nn.Module()
|
27 |
+
|
28 |
+
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
|
29 |
+
zero_convs = torch.nn.ModuleList()
|
30 |
+
for i, dim in enumerate(dims):
|
31 |
+
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
|
32 |
+
zero_convs.append(sub_list)
|
33 |
+
self.control_model.add_module("zero_convs", zero_convs)
|
34 |
+
|
35 |
+
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
|
36 |
+
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
|
37 |
+
|
38 |
+
dims = [16, 16, 32, 32, 96, 96, 256, 320]
|
39 |
+
strides = [1, 1, 2, 1, 2, 1, 2, 1]
|
40 |
+
prev_dim = 3
|
41 |
+
input_hint_block = torch.nn.Sequential()
|
42 |
+
for i, (dim, stride) in enumerate(zip(dims, strides)):
|
43 |
+
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
|
44 |
+
if i < len(dims) - 1:
|
45 |
+
input_hint_block.append(torch.nn.SiLU())
|
46 |
+
prev_dim = dim
|
47 |
+
self.control_model.add_module("input_hint_block", input_hint_block)
|
48 |
+
|
49 |
+
|
50 |
+
def load_control_net(v2, unet, model):
|
51 |
+
device = unet.device
|
52 |
+
|
53 |
+
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
54 |
+
# state dictを読み込む
|
55 |
+
print(f"ControlNet: loading control SD model : {model}")
|
56 |
+
|
57 |
+
if model_util.is_safetensors(model):
|
58 |
+
ctrl_sd_sd = load_file(model)
|
59 |
+
else:
|
60 |
+
ctrl_sd_sd = torch.load(model, map_location='cpu')
|
61 |
+
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
62 |
+
|
63 |
+
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
64 |
+
is_difference = "difference" in ctrl_sd_sd
|
65 |
+
print("ControlNet: loading difference")
|
66 |
+
|
67 |
+
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
68 |
+
# またTransfer Controlの元weightとなる
|
69 |
+
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
70 |
+
|
71 |
+
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
|
72 |
+
for key in list(ctrl_unet_sd_sd.keys()):
|
73 |
+
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
|
74 |
+
|
75 |
+
zero_conv_sd = {}
|
76 |
+
for key in list(ctrl_sd_sd.keys()):
|
77 |
+
if key.startswith("control_"):
|
78 |
+
unet_key = "model.diffusion_" + key[len("control_"):]
|
79 |
+
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
80 |
+
zero_conv_sd[key] = ctrl_sd_sd[key]
|
81 |
+
continue
|
82 |
+
if is_difference: # Transfer Control
|
83 |
+
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
84 |
+
else:
|
85 |
+
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
86 |
+
|
87 |
+
unet_config = model_util.create_unet_diffusers_config(v2)
|
88 |
+
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
|
89 |
+
|
90 |
+
# ControlNetのU-Netを作成する
|
91 |
+
ctrl_unet = UNet2DConditionModel(**unet_config)
|
92 |
+
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
93 |
+
print("ControlNet: loading Control U-Net:", info)
|
94 |
+
|
95 |
+
# U-Net以外のControlNetを作成する
|
96 |
+
# TODO support middle only
|
97 |
+
ctrl_net = ControlNet()
|
98 |
+
info = ctrl_net.load_state_dict(zero_conv_sd)
|
99 |
+
print("ControlNet: loading ControlNet:", info)
|
100 |
+
|
101 |
+
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
102 |
+
ctrl_net.to(unet.device, dtype=unet.dtype)
|
103 |
+
return ctrl_unet, ctrl_net
|
104 |
+
|
105 |
+
|
106 |
+
def load_preprocess(prep_type: str):
|
107 |
+
if prep_type is None or prep_type.lower() == "none":
|
108 |
+
return None
|
109 |
+
|
110 |
+
if prep_type.startswith("canny"):
|
111 |
+
args = prep_type.split("_")
|
112 |
+
th1 = int(args[1]) if len(args) >= 2 else 63
|
113 |
+
th2 = int(args[2]) if len(args) >= 3 else 191
|
114 |
+
|
115 |
+
def canny(img):
|
116 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
117 |
+
return cv2.Canny(img, th1, th2)
|
118 |
+
return canny
|
119 |
+
|
120 |
+
print("Unsupported prep type:", prep_type)
|
121 |
+
return None
|
122 |
+
|
123 |
+
|
124 |
+
def preprocess_ctrl_net_hint_image(image):
|
125 |
+
image = np.array(image).astype(np.float32) / 255.0
|
126 |
+
image = image[:, :, ::-1].copy() # rgb to bgr
|
127 |
+
image = image[None].transpose(0, 3, 1, 2) # nchw
|
128 |
+
image = torch.from_numpy(image)
|
129 |
+
return image # 0 to 1
|
130 |
+
|
131 |
+
|
132 |
+
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
|
133 |
+
guided_hints = []
|
134 |
+
for i, cnet_info in enumerate(control_nets):
|
135 |
+
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
|
136 |
+
b_hints = []
|
137 |
+
if len(hints) == 1: # すべて同じ画像をhintとして使う
|
138 |
+
hint = hints[0]
|
139 |
+
if cnet_info.prep is not None:
|
140 |
+
hint = cnet_info.prep(hint)
|
141 |
+
hint = preprocess_ctrl_net_hint_image(hint)
|
142 |
+
b_hints = [hint for _ in range(b_size)]
|
143 |
+
else:
|
144 |
+
for bi in range(b_size):
|
145 |
+
hint = hints[(bi * len(control_nets) + i) % len(hints)]
|
146 |
+
if cnet_info.prep is not None:
|
147 |
+
hint = cnet_info.prep(hint)
|
148 |
+
hint = preprocess_ctrl_net_hint_image(hint)
|
149 |
+
b_hints.append(hint)
|
150 |
+
b_hints = torch.cat(b_hints, dim=0)
|
151 |
+
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
|
152 |
+
|
153 |
+
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
|
154 |
+
guided_hints.append(guided_hint)
|
155 |
+
return guided_hints
|
156 |
+
|
157 |
+
|
158 |
+
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
|
159 |
+
# ControlNet
|
160 |
+
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
161 |
+
cnet_cnt = len(control_nets)
|
162 |
+
cnet_idx = step % cnet_cnt
|
163 |
+
cnet_info = control_nets[cnet_idx]
|
164 |
+
|
165 |
+
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
166 |
+
if cnet_info.ratio < current_ratio:
|
167 |
+
return original_unet(sample, timestep, encoder_hidden_states)
|
168 |
+
|
169 |
+
guided_hint = guided_hints[cnet_idx]
|
170 |
+
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
171 |
+
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
172 |
+
outs = [o * cnet_info.weight for o in outs]
|
173 |
+
|
174 |
+
# U-Net
|
175 |
+
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
|
176 |
+
|
177 |
+
|
178 |
+
"""
|
179 |
+
# これはmergeのバージョン
|
180 |
+
# ControlNet
|
181 |
+
cnet_outs_list = []
|
182 |
+
for i, cnet_info in enumerate(control_nets):
|
183 |
+
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
184 |
+
if cnet_info.ratio < current_ratio:
|
185 |
+
continue
|
186 |
+
guided_hint = guided_hints[i]
|
187 |
+
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
188 |
+
for i in range(len(outs)):
|
189 |
+
outs[i] *= cnet_info.weight
|
190 |
+
|
191 |
+
cnet_outs_list.append(outs)
|
192 |
+
|
193 |
+
count = len(cnet_outs_list)
|
194 |
+
if count == 0:
|
195 |
+
return original_unet(sample, timestep, encoder_hidden_states)
|
196 |
+
|
197 |
+
# sum of controlnets
|
198 |
+
for i in range(1, count):
|
199 |
+
cnet_outs_list[0] += cnet_outs_list[i]
|
200 |
+
|
201 |
+
# U-Net
|
202 |
+
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
|
203 |
+
"""
|
204 |
+
|
205 |
+
|
206 |
+
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
|
207 |
+
# copy from UNet2DConditionModel
|
208 |
+
default_overall_up_factor = 2**unet.num_upsamplers
|
209 |
+
|
210 |
+
forward_upsample_size = False
|
211 |
+
upsample_size = None
|
212 |
+
|
213 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
214 |
+
print("Forward upsample size to force interpolation output size.")
|
215 |
+
forward_upsample_size = True
|
216 |
+
|
217 |
+
# 0. center input if necessary
|
218 |
+
if unet.config.center_input_sample:
|
219 |
+
sample = 2 * sample - 1.0
|
220 |
+
|
221 |
+
# 1. time
|
222 |
+
timesteps = timestep
|
223 |
+
if not torch.is_tensor(timesteps):
|
224 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
225 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
226 |
+
is_mps = sample.device.type == "mps"
|
227 |
+
if isinstance(timestep, float):
|
228 |
+
dtype = torch.float32 if is_mps else torch.float64
|
229 |
+
else:
|
230 |
+
dtype = torch.int32 if is_mps else torch.int64
|
231 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
232 |
+
elif len(timesteps.shape) == 0:
|
233 |
+
timesteps = timesteps[None].to(sample.device)
|
234 |
+
|
235 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
236 |
+
timesteps = timesteps.expand(sample.shape[0])
|
237 |
+
|
238 |
+
t_emb = unet.time_proj(timesteps)
|
239 |
+
|
240 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
241 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
242 |
+
# there might be better ways to encapsulate this.
|
243 |
+
t_emb = t_emb.to(dtype=unet.dtype)
|
244 |
+
emb = unet.time_embedding(t_emb)
|
245 |
+
|
246 |
+
outs = [] # output of ControlNet
|
247 |
+
zc_idx = 0
|
248 |
+
|
249 |
+
# 2. pre-process
|
250 |
+
sample = unet.conv_in(sample)
|
251 |
+
if is_control_net:
|
252 |
+
sample += guided_hint
|
253 |
+
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
|
254 |
+
zc_idx += 1
|
255 |
+
|
256 |
+
# 3. down
|
257 |
+
down_block_res_samples = (sample,)
|
258 |
+
for downsample_block in unet.down_blocks:
|
259 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
260 |
+
sample, res_samples = downsample_block(
|
261 |
+
hidden_states=sample,
|
262 |
+
temb=emb,
|
263 |
+
encoder_hidden_states=encoder_hidden_states,
|
264 |
+
)
|
265 |
+
else:
|
266 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
267 |
+
if is_control_net:
|
268 |
+
for rs in res_samples:
|
269 |
+
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
|
270 |
+
zc_idx += 1
|
271 |
+
|
272 |
+
down_block_res_samples += res_samples
|
273 |
+
|
274 |
+
# 4. mid
|
275 |
+
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
276 |
+
if is_control_net:
|
277 |
+
outs.append(control_net.control_model.middle_block_out[0](sample))
|
278 |
+
return outs
|
279 |
+
|
280 |
+
if not is_control_net:
|
281 |
+
sample += ctrl_outs.pop()
|
282 |
+
|
283 |
+
# 5. up
|
284 |
+
for i, upsample_block in enumerate(unet.up_blocks):
|
285 |
+
is_final_block = i == len(unet.up_blocks) - 1
|
286 |
+
|
287 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
288 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
289 |
+
|
290 |
+
if not is_control_net and len(ctrl_outs) > 0:
|
291 |
+
res_samples = list(res_samples)
|
292 |
+
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
|
293 |
+
ctrl_outs = ctrl_outs[:-len(res_samples)]
|
294 |
+
for j in range(len(res_samples)):
|
295 |
+
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
296 |
+
res_samples = tuple(res_samples)
|
297 |
+
|
298 |
+
# if we have not reached the final block and need to forward the
|
299 |
+
# upsample size, we do it here
|
300 |
+
if not is_final_block and forward_upsample_size:
|
301 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
302 |
+
|
303 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
304 |
+
sample = upsample_block(
|
305 |
+
hidden_states=sample,
|
306 |
+
temb=emb,
|
307 |
+
res_hidden_states_tuple=res_samples,
|
308 |
+
encoder_hidden_states=encoder_hidden_states,
|
309 |
+
upsample_size=upsample_size,
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
sample = upsample_block(
|
313 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
314 |
+
)
|
315 |
+
# 6. post-process
|
316 |
+
sample = unet.conv_norm_out(sample)
|
317 |
+
sample = unet.conv_act(sample)
|
318 |
+
sample = unet.conv_out(sample)
|
319 |
+
|
320 |
+
return UNet2DConditionOutput(sample=sample)
|
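A sketch of how these helpers fit together; `unet` is assumed to be an already-loaded Diffusers `UNet2DConditionModel`, and the model path, weight and ratio values are placeholders:

```python
import tools.original_control_net as ocn

def build_controlnet_info(v2, unet, model_path="control_canny.safetensors"):
    # load the control SD checkpoint into a Diffusers U-Net plus the zero-conv / hint blocks
    ctrl_unet, ctrl_net = ocn.load_control_net(v2, unet, model_path)
    prep = ocn.load_preprocess("canny_63_191")   # canny preprocessor with thresholds 63 / 191
    # weight scales the ControlNet residuals; the net stays active while the sampling
    # progress ratio is at or below `ratio` (see call_unet_and_control_net above)
    return ocn.ControlNetInfo(unet=ctrl_unet, net=ctrl_net, prep=prep, weight=1.0, ratio=1.0)
```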
train_README-ja.md
ADDED
@@ -0,0 +1,936 @@
|
1 |
+
__ドキュメント更新中のため記述に誤りがあるかもしれません。__
|
2 |
+
|
3 |
+
# 学習について、共通編
|
4 |
+
|
5 |
+
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
6 |
+
|
7 |
+
# 概要
|
8 |
+
|
9 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
10 |
+
|
11 |
+
|
12 |
+
以下について説明します。
|
13 |
+
|
14 |
+
1. 学習データの準備について(設定ファイルを用いる新形式)
|
15 |
+
1. 学習で使われる用語のごく簡単な解説
|
16 |
+
1. 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
|
17 |
+
1. 学習途中のサンプル画像生成
|
18 |
+
1. 各スクリプトで共通の、よく使われるオプション
|
19 |
+
1. fine tuning 方式のメタデータ準備:キャプションニングなど
|
20 |
+
|
21 |
+
1.だけ実行すればとりあえず学習は可能です(学習については各スクリプトのドキュメントを参照)。2.以降は必要に応じて参照してください。
|
22 |
+
|
23 |
+
|
24 |
+
# 学習データの準備について
|
25 |
+
|
26 |
+
任意のフォルダ(複数でも可)に学習データの画像ファイルを用意しておきます。`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` をサポートします。リサイズなどの前処理は基本的に必要ありません。
|
27 |
+
|
28 |
+
ただし学習解像度(後述)よりも極端に小さい画像は使わないか、あらかじめ超解像AIなどで拡大しておくことをお勧めします。また極端に大きな画像(3000x3000ピクセル程度?)よりも大きな画像はエラーになる場合があるようですので事前に縮小してください。
|
29 |
+
|
30 |
+
学習時には、モデルに学ばせる画像データを整理し、スクリプトに対して指定する必要があります。学習データの数、学習対象、キャプション(画像の説明)が用意できるか否かなどにより、いくつかの方法で学習データを指定できます。以下の方式があります(それぞれの名前は一般的なものではなく、当リポジトリ独自の定義です)。正則化画像については後述します。
|
31 |
+
|
32 |
+
1. DreamBooth、class+identifier方式(正則化画像使用可)
|
33 |
+
|
34 |
+
特定の単語 (identifier) に学習対象を紐づけるように学習します。キャプションを用意する必要はありません。たとえば特定のキャラを学ばせる場合に使うとキャプションを用意する必要がない分、手軽ですが、髪型や服装、背景など学習データの全要素が identifier に紐づけられて学習されるため、生成時のプロンプトで服が変えられない、といった事態も起こりえます。
|
35 |
+
|
36 |
+
1. DreamBooth、キャプション方式(正則化画像使用可)
|
37 |
+
|
38 |
+
画像ごとにキャプションが記録されたテキストファイルを用意して学習します。たとえば特定のキャラを学ばせると、画像の詳細をキャプションに記述することで(白い服を着たキャラA、赤い服を着たキャラA、など)キャラとそれ以外の要素が分離され、より厳密にモデルがキャラだけを学ぶことが期待できます。
|
39 |
+
|
40 |
+
1. fine tuning方式(正則化画像使用不可)
|
41 |
+
|
42 |
+
あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。(fine tuning方式という名前ですが fine tuning 以外でも使えます。)
|
43 |
+
|
44 |
+
学習したいものと使用できる指定方法の組み合わせは以下の通りです。
|
45 |
+
|
46 |
+
| 学習対象または方法 | スクリプト | DB / class+identifier | DB / キャプション | fine tuning |
|
47 |
+
| ----- | ----- | ----- | ----- | ----- |
|
48 |
+
| モデルをfine tuning | `fine_tune.py`| x | x | o |
|
49 |
+
| モデルをDreamBooth | `train_db.py`| o | o | x |
|
50 |
+
| LoRA | `train_network.py`| o | o | o |
|
51 |
+
| Textual Inversion | `train_textual_inversion.py`| o | o | o |
|
52 |
+
|
53 |
+
## どれを選ぶか
|
54 |
+
|
55 |
+
LoRA、Textual Inversionについては、手軽にキャプションファイルを用意せずに学習したい場合はDreamBooth class+identifier、用意できるならDreamBooth キャプション方式がよいでしょう。学習データの枚数が多く、かつ正則化画像を使用しない場合はfine tuning方式も検討してください。
|
56 |
+
|
57 |
+
DreamBoothについても同様ですが、fine tuning方式は使えません。fine tuningの場合はfine tuning方式のみです。
|
58 |
+
|
59 |
+
# 各方式の指定方法について
|
60 |
+
|
61 |
+
ここではそれぞれの指定方法で典型的なパターンについてだけ説明します。より詳細な指定方法については [データセット設定](./config_README-ja.md) をご覧ください。
|
62 |
+
|
63 |
+
# DreamBooth、class+identifier方式(正則化画像使用可)
|
64 |
+
|
65 |
+
この方式では、各画像は `class identifier` というキャプションで学習されたのと同じことになります(`shs dog` など)。
|
66 |
+
|
67 |
+
## step 1. identifierとclassを決める
|
68 |
+
|
69 |
+
学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。
|
70 |
+
|
71 |
+
(instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。)
|
72 |
+
|
73 |
+
以下ごく簡単に説明します(詳しくは調べてください)。
|
74 |
+
|
75 |
+
classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。
|
76 |
+
|
77 |
+
identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokinizerで1トークンになる3文字以下でレアな単語」が良いとのことです。
|
78 |
+
|
79 |
+
identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。
|
80 |
+
|
81 |
+
画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。
|
82 |
+
|
83 |
+
(identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。本当は Danbooru Tag に含まれないやつがより望ましいです。)
|
84 |
+
|
85 |
+
## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
|
86 |
+
|
87 |
+
正則化画像とは、前述のclass全体が、学習対象に引っ張られることを防ぐための画像です(language drift)。正則化画像を使わないと、たとえば `shs 1girl` で特定のキャラクタを学ばせると、単なる `1girl` というプロンプトで生成してもそのキャラに似てきます。これは `1girl` が学習時のキャプションに含まれているためです。
|
88 |
+
|
89 |
+
学習対象の画像と正則化画像を同時に学ばせることで、class は class のままで留まり、identifier をプロンプトにつけた時だけ学習対象が生成されるようになります。
|
90 |
+
|
91 |
+
LoRAやDreamBoothで特定のキャラだけ出てくればよい場合は、正則化画像を用いなくても良いといえます。
|
92 |
+
|
93 |
+
Textual Inversionでは用いなくてよいでしょう(学ばせる token string がキャプションに含まれない場合はなにも学習されないため)。
|
94 |
+
|
95 |
+
正則化画像としては、学習対象のモデルで、class 名だけで生成した画像を用いるのが一般的です(たとえば `1girl`)。ただし生成画像の品質が悪い場合には、プロンプトを工夫したり、ネットから別途ダウンロードした画像を用いることもできます。
|
96 |
+
|
97 |
+
(正則化画像も学習されるため、その品質はモデルに影響します。)
|
98 |
+
|
99 |
+
一般的には数百枚程度、用意するのが望ましいようです(枚数が少ないと class 画像が一般化されずそれらの特徴を学んでしまいます)。
|
100 |
+
|
101 |
+
生成画像を使う場合、通常、生成画像のサイズは学習解像度(より正確にはbucketの解像度、後述)にあわせてください。
|
102 |
+
|
103 |
+
## step 2. 設定ファイルの記述
|
104 |
+
|
105 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
106 |
+
|
107 |
+
(`#` で始まっている部分はコメントですので、このままコピペしてそのままでもよいですし、削除しても問題ありません。)
|
108 |
+
|
109 |
+
```toml
|
110 |
+
[general]
|
111 |
+
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
|
112 |
+
|
113 |
+
[[datasets]]
|
114 |
+
resolution = 512 # 学習解像度
|
115 |
+
batch_size = 4 # バッチサイズ
|
116 |
+
|
117 |
+
[[datasets.subsets]]
|
118 |
+
image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
|
119 |
+
class_tokens = 'hoge girl' # identifier class を指定
|
120 |
+
num_repeats = 10 # 学習用画像の繰り返し回数
|
121 |
+
|
122 |
+
# 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
|
123 |
+
[[datasets.subsets]]
|
124 |
+
is_reg = true
|
125 |
+
image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
|
126 |
+
class_tokens = 'girl' # class を指定
|
127 |
+
num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
|
128 |
+
```
|
129 |
+
|
130 |
+
基本的には以下の場所のみ書き換えれば学習できます。
|
131 |
+
|
132 |
+
1. 学習解像度
|
133 |
+
|
134 |
+
数値1つを指定すると正方形(`512`なら512x512)、鍵カッコカンマ区切りで2つ指定すると横×縦(`[512,768]`なら512x768)になります。SD1.x系では��ともとの学習解像度は512です。`[512,768]` 等の大きめの解像度を指定すると縦長、横長画像生成時の破綻を小さくできるかもしれません。SD2.x 768系では `768` です。
|
135 |
+
|
136 |
+
1. バッチサイズ
|
137 |
+
|
138 |
+
同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。詳しくは後述します。またfine tuning/DreamBooth/LoRA等でも変わってきますので各スクリプトの説明もご覧ください。
|
139 |
+
|
140 |
+
1. フォルダ指定
|
141 |
+
|
142 |
+
学習用画像、正則化画像(使用する場合のみ)のフォルダを指定します。画像データが含まれているフォルダそのものを指定します。
|
143 |
+
|
144 |
+
1. identifier と class の指定
|
145 |
+
|
146 |
+
前述のサンプルの通りです。
|
147 |
+
|
148 |
+
1. 繰り返し回数
|
149 |
+
|
150 |
+
後述します。
|
151 |
+
|
152 |
+
### 繰り返し回数について
|
153 |
+
|
154 |
+
繰り返し回数は、正則化画像の枚数と学習用画像の枚数を調整するために用いられます。正則化画像の枚数は学習用画像よりも多いため、学習用画像を繰り返して枚数を合わせ、1対1の比率で学習できるようにします。
|
155 |
+
|
156 |
+
繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。
|
157 |
+
|
158 |
+
(1 epoch(データが一周すると1 epoch)のデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。)
|
159 |
+
|
160 |
+
## step 3. 学習
|
161 |
+
|
162 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
163 |
+
|
164 |
+
# DreamBooth、キャプション方式(正則化画像使用可)
|
165 |
+
|
166 |
+
この方式では各画像はキャプションで学習されます。
|
167 |
+
|
168 |
+
## step 1. キャプションファイルを準備する
|
169 |
+
|
170 |
+
学習用画像のフォルダに、画像と同じファイル名で、拡張子 `.caption`(設定で変えられます)のファイルを置いてください。それぞれのファイルは1行のみとしてください。エンコーディングは `UTF-8` です。
|
171 |
+
|
172 |
+
## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
|
173 |
+
|
174 |
+
class+identifier形式と同様です。なお正則化画像にもキャプションを付けることができますが、通常は不要でしょう。
|
175 |
+
|
176 |
+
## step 2. 設定ファイルの記述
|
177 |
+
|
178 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
179 |
+
|
180 |
+
```toml
|
181 |
+
[general]
|
182 |
+
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
|
183 |
+
|
184 |
+
[[datasets]]
|
185 |
+
resolution = 512 # 学習解像度
|
186 |
+
batch_size = 4 # バッチサイズ
|
187 |
+
|
188 |
+
[[datasets.subsets]]
|
189 |
+
image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
|
190 |
+
caption_extension = '.caption' # キャプションファイルの拡張子 .txt を使う場合には書き換える
|
191 |
+
num_repeats = 10 # 学習用画像の繰り返し回数
|
192 |
+
|
193 |
+
# 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
|
194 |
+
[[datasets.subsets]]
|
195 |
+
is_reg = true
|
196 |
+
image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
|
197 |
+
class_tokens = 'girl' # class を指定
|
198 |
+
num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
|
199 |
+
```
|
200 |
+
|
201 |
+
基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は class+identifier 方式と同じです。
|
202 |
+
|
203 |
+
1. 学習解像度
|
204 |
+
1. バッチサイズ
|
205 |
+
1. フォルダ指定
|
206 |
+
1. キャプションファイルの拡張子
|
207 |
+
|
208 |
+
任意の拡張子を指定できます。
|
209 |
+
1. 繰り返し回数
|
210 |
+
|
211 |
+
## step 3. 学習
|
212 |
+
|
213 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
214 |
+
|
215 |
+
# fine tuning 方式
|
216 |
+
|
217 |
+
## step 1. メタデータを準備する
|
218 |
+
|
219 |
+
キャプションやタグをまとめた管理用ファイルをメタデータと呼びます。json形式で拡張子は `.json`
|
220 |
+
です。作成方法は長くなりますのでこの文書の末尾に書きました。
|
221 |
+
|
222 |
+
## step 2. 設定ファイルの記述
|
223 |
+
|
224 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
225 |
+
|
226 |
+
```toml
|
227 |
+
[general]
|
228 |
+
shuffle_caption = true
|
229 |
+
keep_tokens = 1
|
230 |
+
|
231 |
+
[[datasets]]
|
232 |
+
resolution = 512 # 学習解像度
|
233 |
+
batch_size = 4 # バッチサイズ
|
234 |
+
|
235 |
+
[[datasets.subsets]]
|
236 |
+
image_dir = 'C:\piyo' # 学習用画像を入れたフォルダを指定
|
237 |
+
metadata_file = 'C:\piyo\piyo_md.json' # メタデータファイル名
|
238 |
+
```
|
239 |
+
|
240 |
+
基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は DreamBooth, class+identifier 方式と同じです。
|
241 |
+
|
242 |
+
1. 学習解像度
|
243 |
+
1. バッチサイズ
|
244 |
+
1. フォルダ指定
|
245 |
+
1. メタデータファイル名
|
246 |
+
|
247 |
+
後述の方法で作成したメタデータファイルを指定します。
|
248 |
+
|
249 |
+
|
250 |
+
## step 3. 学習
|
251 |
+
|
252 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
253 |
+
|
254 |
+
# 学習で使われる用語のごく簡単な解説
|
255 |
+
|
256 |
+
細かいことは省略していますし私も完全には理解していないため、詳しくは各自お調べください。
|
257 |
+
|
258 |
+
## fine tuning(ファインチューニング)
|
259 |
+
|
260 |
+
モデルを学習して微調整することを指します。使われ方によって意味が異なってきますが、狭義のfine tuningはStable Diffusionの場合、モデルを画像とキャプションで学習することです。DreamBoothは狭義のfine tuningのひとつの特殊なやり方と言えます。広義のfine tuningは、LoRAやTextual Inversion、Hypernetworksなどを含み、モデルを学習することすべてを含みます。
|
261 |
+
|
262 |
+
## ステップ
|
263 |
+
|
264 |
+
ざっくりいうと学習データで1回計算すると1ステップです。「学習データのキャプションを今のモデルに流してみて、出てくる画像を学習データの画像と比較し、学習データに近づくようにモデルをわずかに変更する」のが1ステップです。
|
265 |
+
|
266 |
+
## バッチサイズ
|
267 |
+
|
268 |
+
バッチサイズは1ステップで何件のデータをまとめて計算するかを指定する値です。まとめて計算するため速度は相対的に向上します。また一般的には精度も高くなるといわれています。
|
269 |
+
|
270 |
+
`バッチサイズ×ステップ数` が学習に使われるデータの件数になります。そのため、バッチサイズを増やした分だけステップ数を減らすとよいでしょう。
|
271 |
+
|
272 |
+
(ただし、たとえば「バッチサイズ1で1600ステップ」と「バッチサイズ4で400ステップ」は同じ結果にはなりません。同じ学習率の場合、一般的には後者のほうが学習不足になります。学習率を多少大きくするか(たとえば `2e-6` など)、ステップ数をたとえば500ステップにするなどして工夫してください。)
|
273 |
+
|
274 |
+
バッチサイズを大きくするとその分だけGPUメモリを消費します。メモリが足りなくなるとエラーになりますし、エラーにならないギリギリでは学習速度が低下します。タスクマネージャーや `nvidia-smi` コマンドで使用メモリ量を確認しながら調整するとよいでしょう。
|
275 |
+
|
276 |
+
なお、バッチは「一塊のデータ」位の意味です。
|
277 |
+
|
278 |
+
## 学習率
|
279 |
+
|
280 |
+
ざっくりいうと1ステップごとにどのくらい変化させるかを表します。大きな値を指定するとそれだけ速く学習が進みますが、変化しすぎてモデルが壊れたり、最適な状態にまで至れない場合があります。小さい値を指定すると学習速度は遅くなり、また最適な状態にやはり至れない場合があります。
|
281 |
+
|
282 |
+
fine tuning、DreamBoooth、LoRAそれぞれで大きく異なり、また学習データや学習させたいモデル、バッチサイズやステップ数によっても変わってきます。一般的な値から初めて学習状態を見ながら増減してください。
|
283 |
+
|
284 |
+
デフォルトでは学習全体を通して学習率は固定です。スケジューラの指定で学習率をどう変化させるか決められますので、それらによっても結果は変わってきます。
|
285 |
+
|
286 |
+
## エポック(epoch)
|
287 |
+
|
288 |
+
学習データが一通り学習されると(データが一周すると)1 epochです。繰り返し回数を指定した場合は、その繰り返し後のデータが一周すると1 epochです。
|
289 |
+
|
290 |
+
1 epochのステップ数は、基本的には `データ件数÷バッチサイズ` ですが、Aspect Ratio Bucketing を使うと微妙に増えます(異なるbucketのデータは同じバッチにできないため、ステップ数が増えます)。
|
291 |
+
|
292 |
+
## Aspect Ratio Bucketing
|
293 |
+
|
294 |
+
Stable Diffusion のv1は512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくキャプションと画像の関係が学習されることが期待されます。
|
295 |
+
|
296 |
+
また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。
|
297 |
+
|
298 |
+
設定で有効、無効が切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
|
299 |
+
|
300 |
+
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位(デフォルト、変更可)で縦横に調整、作成されます。
|
301 |
+
|
302 |
+
機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
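参考までに、「面積が基準解像度を超えない範囲で64ピクセル単位の解像度を列挙する」という考え方を示す簡単なPythonスケッチを以下に載せます。実際のスクリプト内部の実装とは列挙方法や丸め方などの細部が異なる可能性がありますので、あくまでイメージをつかむためのものとお考えください。

```python
# Aspect Ratio Bucketing の考え方を示す簡単なスケッチです(実装の細部は実際のスクリプトと異なる可能性があります)。
def make_buckets(max_width=512, max_height=512, step=64, min_size=256, max_size=1024):
    max_area = max_width * max_height  # 面積(≒メモリ使用量)の上限
    buckets = set()
    w = min_size
    while w <= max_size:
        # 幅 w に対して、面積上限を超えない最大の高さを step 単位で求める
        h = min(max_size, (max_area // w) // step * step)
        if h >= min_size:
            buckets.add((w, h))
            buckets.add((h, w))
        w += step
    return sorted(buckets)

if __name__ == "__main__":
    for w, h in make_buckets():
        print(f"{w}x{h}")  # 256x1024, 320x768, 384x640, 512x512 などが列挙される
```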
|
303 |
+
|
304 |
+
# 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
|
305 |
+
|
306 |
+
`.toml` ファイルを指定せずコマンドラインオプションで指定する方法です。DreamBooth class+identifier方式、DreamBooth キャプション方式、fine tuning方式があります。
|
307 |
+
|
308 |
+
## DreamBooth、class+identifier方式
|
309 |
+
|
310 |
+
フォルダ名で繰り返し回数を指定します。また `train_data_dir` オプションと `reg_data_dir` オプションを用います。
|
311 |
+
|
312 |
+
### step 1. 学習用画像の準備
|
313 |
+
|
314 |
+
学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。
|
315 |
+
|
316 |
+
```
|
317 |
+
<繰り返し回数>_<identifier> <class>
|
318 |
+
```
|
319 |
+
|
320 |
+
間の``_``を忘れないでください。
|
321 |
+
|
322 |
+
たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、フォルダ名は「20_sls frog」となります。構成は以下のようになります。
|
323 |
+
|
324 |
+

|
325 |
+
|
326 |
+
### 複数class、複数対象(identifier)の学習
|
327 |
+
|
328 |
+
方法は単純で、学習用画像のフォルダ内に ``繰り返し回数_<identifier> <class>`` のフォルダを複数、正則化画像フォルダにも同様に ``繰り返し回数_<class>`` のフォルダを複数、用意してください。
|
329 |
+
|
330 |
+
たとえば「sls frog」と「cpc rabbit」を同時に学習する場合、以下のようになります。
|
331 |
+
|
332 |
+

|
333 |
+
|
334 |
+
classがひとつで対象が複数の場合、正則化画像フォルダはひとつで構いません。たとえば1girlにキャラAとキャラBがいる場合は次のようにします。
|
335 |
+
|
336 |
+
- train_girls
|
337 |
+
  - 10_sls 1girl
|
338 |
+
  - 10_cpc 1girl
|
339 |
+
- reg_girls
|
340 |
+
  - 1_1girl
|
341 |
+
|
342 |
+
### step 2. 正則化画像の準備
|
343 |
+
|
344 |
+
正則化画像を使う場合の手順です。
|
345 |
+
|
346 |
+
正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_<class>`` という名前でディレクトリを作成します。
|
347 |
+
|
348 |
+
たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。
|
349 |
+
|
350 |
+

|
351 |
+
|
352 |
+
|
353 |
+
### step 3. 学習の実行
|
354 |
+
|
355 |
+
各学習スクリプトを実行します。 `--train_data_dir` オプションで前述の学習用データのフォルダを(__画像を含むフォルダではなく、その親フォルダ__)、`--reg_data_dir` オプションで正則化画像のフォルダ(__画像を含むフォルダではなく、その親フォルダ__)を指定してください。
|
356 |
+
|
357 |
+
## DreamBooth、キャプション方式
|
358 |
+
|
359 |
+
学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。
|
360 |
+
|
361 |
+
※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。
|
362 |
+
|
363 |
+
キャプションファイルの拡張子はデフォルトで.captionです。学習スクリプトの `--caption_extension` オプションで変更できます。`--shuffle_caption` オプションを指定すると、学習時にキャプションのカンマ区切りの各部分をシャッフルしながら学習します。
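たとえば `img001.png` と同じフォルダに `img001.caption` というファイルを置き、中身を `sls frog, sitting on a leaf` のような1行のテキストにします(ファイル名・キャプションの内容はあくまで仮の例です)。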
|
364 |
+
|
365 |
+
## fine tuning 方式
|
366 |
+
|
367 |
+
メタデータを作るところまでは設定ファイルを使う場合と同様です。`in_json` オプションでメタデータファイルを指定します。
|
368 |
+
|
369 |
+
# 学習途中でのサンプル出力
|
370 |
+
|
371 |
+
学習中のモデルで試しに画像生成することで学習の進み方を確認できます。学習スクリプトに以下のオプションを指定します。
|
372 |
+
|
373 |
+
- `--sample_every_n_steps` / `--sample_every_n_epochs`
|
374 |
+
|
375 |
+
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
376 |
+
|
377 |
+
- `--sample_prompts`
|
378 |
+
|
379 |
+
サンプル出力用プロンプトのファイルを指定します。
|
380 |
+
|
381 |
+
- `--sample_sampler`
|
382 |
+
|
383 |
+
サンプル出力に使うサンプラーを指定します。
|
384 |
+
`'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
|
385 |
+
|
386 |
+
サンプル出力を行うにはあらかじめプロンプトを記述したテキストファイルを用意しておく必要があります。1行につき1プロンプトで記述します。
|
387 |
+
|
388 |
+
たとえば以下のようになります。
|
389 |
+
|
390 |
+
```txt
|
391 |
+
# prompt 1
|
392 |
+
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
393 |
+
|
394 |
+
# prompt 2
|
395 |
+
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
396 |
+
```
|
397 |
+
|
398 |
+
先頭が `#` の行はコメントになります。`--n` のように 「`--` + 英小文字」で生成画像へのオプションを指定できます。以下が使えます。
|
399 |
+
|
400 |
+
- `--n` 次のオプションまでをネガティブプロンプトとします。
|
401 |
+
- `--w` 生成画像の横幅を指定します。
|
402 |
+
- `--h` 生成画像の高さを指定します。
|
403 |
+
- `--d` 生成画像のseedを指定します。
|
404 |
+
- `--l` 生成画像のCFG scaleを指定します。
|
405 |
+
- `--s` 生成時のステップ数を指定します。
|
406 |
+
|
407 |
+
|
408 |
+
# 各スクリプトで共通の、よく使われるオプション
|
409 |
+
|
410 |
+
スクリプトの更新後、ドキュメントの更新が追い付いていない場合があります。その場合は `--help` オプションで使用できるオプションを確認してください。
|
411 |
+
|
412 |
+
## 学習に使うモデル指定
|
413 |
+
|
414 |
+
- `--v2` / `--v_parameterization`
|
415 |
+
|
416 |
+
学習対象モデルとしてHugging Faceのstable-diffusion-2-base、またはそこからのfine tuningモデルを使う場合(推論時に `v2-inference.yaml` を使うように指示されているモデルの場合)は `--v2` オプションを、stable-diffusion-2や768-v-ema.ckpt、およびそれらのfine tuningモデルを使う場合(推論時に `v2-inference-v.yaml` を使うモデルの場合)は `--v2` と `--v_parameterization` の両方のオプションを指定してください。
|
417 |
+
|
418 |
+
Stable Diffusion 2.0では大きく以下の点が変わっています。
|
419 |
+
|
420 |
+
1. 使用するTokenizer
|
421 |
+
2. 使用するText Encoderおよび使用する出力層(2.0は最後から二番目の層を使う)
|
422 |
+
3. Text Encoderの出力次元数(768->1024)
|
423 |
+
4. U-Netの構造(CrossAttentionのhead数など)
|
424 |
+
5. v-parameterization(サンプリング方法が変更されているらしい)
|
425 |
+
|
426 |
+
このうちbaseでは1~4が、baseのつかない方(768-v)では1~5が採用されています。1~4を有効にするのがv2オプション、5を有効にするのがv_parameterizationオプションです。
|
427 |
+
|
428 |
+
- `--pretrained_model_name_or_path`
|
429 |
+
|
430 |
+
追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
431 |
+
|
432 |
+
## 学習に関する設定
|
433 |
+
|
434 |
+
- `--output_dir`
|
435 |
+
|
436 |
+
学習後のモデルを保存するフォルダを指定します。
|
437 |
+
|
438 |
+
- `--output_name`
|
439 |
+
|
440 |
+
モデルのファイル名を拡張子を除いて指定します。
|
441 |
+
|
442 |
+
- `--dataset_config`
|
443 |
+
|
444 |
+
データセットの設定を記述した `.toml` ファイルを指定します。
|
445 |
+
|
446 |
+
- `--max_train_steps` / `--max_train_epochs`
|
447 |
+
|
448 |
+
学習するステップ数やエポック数を指定します。両方指定するとエポック数のほうが優先されます。
|
449 |
+
|
450 |
+
- `--mixed_precision`
|
451 |
+
|
452 |
+
省メモリ化のため mixed precision (混合精度)で学習します。`--mixed_precision="fp16"` のように指定します。mixed precision なし(デフォルト)と比べて精度が低くなる可能性がありますが、学習に必要なGPUメモリ量が大きく減ります。
|
453 |
+
|
454 |
+
(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。
|
455 |
+
|
456 |
+
- `--gradient_checkpointing`
|
457 |
+
|
458 |
+
学習時の重みの計算をまとめて行うのではなく少しずつ行うことで、学習に必要なGPUメモリ量を減らします。オンオフは精度には影響しませんが、オンにするとバッチサイズを大きくできるため、そちらでの影響はあります。
|
459 |
+
|
460 |
+
また一般的にはオンにすると速度は低下しますが、バッチサイズを大きくできるので、トータルでの学習時間はむしろ短くなるかもしれません。
|
461 |
+
|
462 |
+
- `--xformers` / `--mem_eff_attn`
|
463 |
+
|
464 |
+
xformersオプションを指定するとxformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(xformersよりも速度は遅くなります)。
|
465 |
+
|
466 |
+
- `--save_precision`
|
467 |
+
|
468 |
+
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存します(DreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です)。モデルのサイズを削減したい場合などにお使いください。
|
469 |
+
|
470 |
+
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
471 |
+
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
|
472 |
+
|
473 |
+
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます)。保存先はフォルダになります。
|
474 |
+
|
475 |
+
学習状態は保存先フォルダに `<output_name>-??????-state`(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
|
476 |
+
|
477 |
+
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ(`output_dir` ではなくその中のstateのフォルダ)を指定してください。
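たとえば `--resume=<output_dir>/<output_name>-000010-state` のように、stateフォルダそのものを指定します(フォルダ名は一例です)。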
|
478 |
+
|
479 |
+
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
|
480 |
+
|
481 |
+
- `--save_model_as` (DreamBooth, fine tuning のみ)
|
482 |
+
|
483 |
+
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
|
484 |
+
|
485 |
+
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
|
486 |
+
|
487 |
+
- `--clip_skip`
|
488 |
+
|
489 |
+
`2` を指定すると、Text Encoder (CLIP) の後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。
|
490 |
+
|
491 |
+
※SD2.0はデフォルトで後ろから二番目の層を使うため、SD2.0の学習では指定しないでください。
|
492 |
+
|
493 |
+
学習対象のモデルがもともと二番目の層を使うように学習されている場合は、2を指定するとよいでしょう。
|
494 |
+
|
495 |
+
そうではなく最後の層を使用していた場合はモデル全体がそれを前提に学習されています。そのため改めて二番目の層を使用して学習すると、望ましい学習結果を得るにはある程度の枚数の教師データ、長めの学習が必要になるかもしれません。
|
496 |
+
|
497 |
+
- `--max_token_length`
|
498 |
+
|
499 |
+
デフォルトは75です。`150` または `225` を指定することでトークン長を拡張して学習できます。長いキャプションで学習する場合に指定してください。
|
500 |
+
|
501 |
+
ただし学習時のトークン拡張の仕様は Automatic1111 氏のWeb UIとは微妙に異なるため(分割の仕様など)、必要なければ75で学習することをお勧めします。
|
502 |
+
|
503 |
+
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
|
504 |
+
|
505 |
+
- `--persistent_data_loader_workers`
|
506 |
+
|
507 |
+
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
508 |
+
|
509 |
+
- `--max_data_loader_n_workers`
|
510 |
+
|
511 |
+
データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
|
512 |
+
|
513 |
+
- `--logging_dir` / `--log_prefix`
|
514 |
+
|
515 |
+
学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
|
516 |
+
|
517 |
+
たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
|
518 |
+
また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=db_style1_」などとして識別用にお使いください。
|
519 |
+
|
520 |
+
TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します。
|
521 |
+
|
522 |
+
```
|
523 |
+
tensorboard --logdir=logs
|
524 |
+
```
|
525 |
+
|
526 |
+
(tensorboardは環境整備時にあわせてインストールされると思いますが、もし入っていないなら `pip install tensorboard` で入れてください。)
|
527 |
+
|
528 |
+
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
|
529 |
+
|
530 |
+
- `--noise_offset`
|
531 |
+
|
532 |
+
こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
533 |
+
|
534 |
+
全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。`0.1` 程度の値を指定するとよいようです。
|
535 |
+
|
536 |
+
- `--debug_dataset`
|
537 |
+
|
538 |
+
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。
|
539 |
+
|
540 |
+
※Linux環境(Colabを含む)では画像は表示されません。
|
541 |
+
|
542 |
+
- `--vae`
|
543 |
+
|
544 |
+
vaeオプションにStable Diffusionのcheckpoint、VAEのcheckpointファイル、DiffusersのモデルまたはVAE(ともにローカルまたはHugging FaceのモデルIDが指定できます)のいずれかを指定すると、そのVAEを使って学習します(latentsのキャッシュ時または学習中のlatents取得時)。
|
545 |
+
|
546 |
+
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
|
547 |
+
|
548 |
+
|
549 |
+
## オプティマイザ関係
|
550 |
+
|
551 |
+
- `--optimizer_type`
|
552 |
+
オプティマイザの種類を指定します。以下が指定できます。
|
553 |
+
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
554 |
+
- 過去のバージョンのオプション未指定時と同じ
|
555 |
+
- AdamW8bit : 引数は同上
|
556 |
+
- 過去のバージョンの--use_8bit_adam指定時と同じ
|
557 |
+
- Lion : https://github.com/lucidrains/lion-pytorch
|
558 |
+
- 過去のバージョンの--use_lion_optimizer指定時と同じ
|
559 |
+
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
560 |
+
- SGDNesterov8bit : 引数は同上
|
561 |
+
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
562 |
+
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
563 |
+
- 任意のオプティマイザ
|
564 |
+
|
565 |
+
- `--learning_rate`
|
566 |
+
|
567 |
+
学習率を指定します。適切な学習率は学習スクリプトにより異なりますので、それぞれの説明を参照してください。
|
568 |
+
|
569 |
+
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
|
570 |
+
|
571 |
+
学習率のスケジューラ関連の指定です。
|
572 |
+
|
573 |
+
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。
|
574 |
+
|
575 |
+
lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。
|
576 |
+
|
577 |
+
lr_scheduler_num_cycles は cosine with restartsスケジューラでのリスタート回数、lr_scheduler_power は polynomialスケジューラでのpolynomial power です。
|
578 |
+
|
579 |
+
詳細については各自お調べください。
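指定例としては `--lr_scheduler=cosine_with_restarts --lr_warmup_steps=100 --lr_scheduler_num_cycles=4` のようになります(値はあくまで一例です)。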
|
580 |
+
|
581 |
+
### オプティマイザの指定について
|
582 |
+
|
583 |
+
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
584 |
+
|
585 |
+
オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
|
586 |
+
|
587 |
+
一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。
|
588 |
+
|
589 |
+
D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
|
590 |
+
|
591 |
+
AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
|
592 |
+
|
593 |
+
自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
|
594 |
+
|
595 |
+
学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
|
596 |
+
|
597 |
+
### 任意のオプティマイザを使う
|
598 |
+
|
599 |
+
``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
|
600 |
+
|
601 |
+
(内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。)
|
602 |
+
|
603 |
+
|
604 |
+
<!--
|
605 |
+
## 任意サイズの画像での学習 --resolution
|
606 |
+
正方形以外で学習できます。resolutionに「448,640」のように「幅,高さ」で指定してください。幅と高さは64で割り切れる必要があります。学習用画像、正則化画像のサイズを合わせてください。
|
607 |
+
|
608 |
+
個人的には縦長の画像を生成することが多いため「448,640」などで学習することもあります。
|
609 |
+
|
610 |
+
## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso
|
611 |
+
enable_bucketオプションを指定すると有効になります。Stable Diffusionは512x512で学習されていますが、それに加えて256x768や384x640といった解像度でも学習します。
|
612 |
+
|
613 |
+
このオプションを指定した場合は、学習用画像、正則化画像を特定の解像度に統一する必要はありません。いくつかの解像度(アスペクト比)から最適なものを選び、その解像度で学習します。
|
614 |
+
解像度は64ピクセル単位のため、元画像とアスペクト比が完全に一致しない場合がありますが、その場合は、はみ出した部分がわずかにトリミングされます。
|
615 |
+
|
616 |
+
解像度の最小サイズをmin_bucket_resoオプションで、最大サイズをmax_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。
|
617 |
+
たとえば最小サイズに384を指定すると、256x1024や320x768などの解像度は使わなくなります。
|
618 |
+
解像度を768x768のように大きくした場合、最大サイズに1280などを指定しても良いかもしれません。
|
619 |
+
|
620 |
+
なおAspect Ratio Bucketingを有効にするときには、正則化画像についても、学習用画像と似た傾向の様々な解像度を用意した方がいいかもしれません。
|
621 |
+
|
622 |
+
(ひとつのバッチ内の画像が学習用画像、正則化画像に偏らなくなるため。そこまで大きな影響はないと思いますが……。)
|
623 |
+
|
624 |
+
## augmentation --color_aug / --flip_aug
|
625 |
+
augmentationは学習時に動的にデータを変化させることで、モデルの性能を上げる手法です。color_augで色合いを微妙に変えつつ、flip_augで左右反転をしつつ、学習します。
|
626 |
+
|
627 |
+
動的にデータを変化させるため、cache_latentsオプションと同時に指定できません。
|
628 |
+
|
629 |
+
|
630 |
+
## 勾配をfp16とした学習(実験的機能) --full_fp16
|
631 |
+
full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。
|
632 |
+
これによりSD1.xの512x512サイズでは8GB未満、SD2.xの512x512サイズで12GB未満のVRAM使用量で学習できるようです。
|
633 |
+
|
634 |
+
あらかじめaccelerate configでfp16を指定し、オプションで ``mixed_precision="fp16"`` としてください(bf16では動作しません)。
|
635 |
+
|
636 |
+
メモリ使用量を最小化するためには、xformers、use_8bit_adam、cache_latents、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
|
637 |
+
|
638 |
+
(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
|
639 |
+
|
640 |
+
PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。
|
641 |
+
学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
|
642 |
+
|
643 |
+
-->
|
644 |
+
|
645 |
+
# メタデータファイルの作成
|
646 |
+
|
647 |
+
## 教師データの用意
|
648 |
+
|
649 |
+
前述のように学習させたい画像データを用意し、任意のフォルダに入れてください。
|
650 |
+
|
651 |
+
たとえば以下のように画像を格納します。
|
652 |
+
|
653 |
+

|
654 |
+
|
655 |
+
## 自動キャプショニング
|
656 |
+
|
657 |
+
キャプションを使わずタグだけで学習する場合はスキップしてください。
|
658 |
+
|
659 |
+
また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。
|
660 |
+
|
661 |
+
### BLIPによるキャプショニング
|
662 |
+
|
663 |
+
最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。
|
664 |
+
|
665 |
+
finetuneフォルダ内のmake_captions.pyを実行します。
|
666 |
+
|
667 |
+
```
|
668 |
+
python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
|
669 |
+
```
|
670 |
+
|
671 |
+
バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
672 |
+
|
673 |
+
```
|
674 |
+
python finetune\make_captions.py --batch_size 8 ..\train_data
|
675 |
+
```
|
676 |
+
|
677 |
+
キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。
|
678 |
+
|
679 |
+
batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。
|
680 |
+
max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。
|
681 |
+
caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。
|
682 |
+
|
683 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
684 |
+
|
685 |
+
なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで `--seed 42` のように乱数seedを指定してください。
|
686 |
+
|
687 |
+
その他のオプションは `--help` でヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。
|
688 |
+
|
689 |
+
デフォルトでは拡張子.captionでキャプションファイルが生成されます。
|
690 |
+
|
691 |
+

|
692 |
+
|
693 |
+
たとえば以下のようなキャプションが付きます。
|
694 |
+
|
695 |
+

|
696 |
+
|
697 |
+
## DeepDanbooruによるタグ付け
|
698 |
+
|
699 |
+
danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。
|
700 |
+
|
701 |
+
タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。
|
702 |
+
|
703 |
+
### 環境整備
|
704 |
+
|
705 |
+
DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。
|
706 |
+
またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。
|
707 |
+
|
708 |
+
以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。
|
709 |
+
|
710 |
+

|
711 |
+
|
712 |
+
以下のようなディレクトリ構造にしてください。
|
713 |
+
|
714 |
+

|
715 |
+
|
716 |
+
Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。
|
717 |
+
|
718 |
+
```
|
719 |
+
pip install -r requirements.txt
|
720 |
+
```
|
721 |
+
|
722 |
+
続いてDeepDanbooru自体をインストールします。
|
723 |
+
|
724 |
+
```
|
725 |
+
pip install .
|
726 |
+
```
|
727 |
+
|
728 |
+
以上でタグ付けの環境整備は完了です。
|
729 |
+
|
730 |
+
### タグ付けの実施
|
731 |
+
DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。
|
732 |
+
|
733 |
+
```
|
734 |
+
deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
735 |
+
```
|
736 |
+
|
737 |
+
教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
738 |
+
|
739 |
+
```
|
740 |
+
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
741 |
+
```
|
742 |
+
|
743 |
+
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。
|
744 |
+
|
745 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
746 |
+
|
747 |
+
以下のように生成されます。
|
748 |
+
|
749 |
+

|
750 |
+
|
751 |
+
こんな感じにタグが付きます(すごい情報量……)。
|
752 |
+
|
753 |
+

|
754 |
+
|
755 |
+
## WD14Taggerによるタグ付け
|
756 |
+
|
757 |
+
DeepDanbooruの代わりにWD14Taggerを用いる手順です。
|
758 |
+
|
759 |
+
Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
|
760 |
+
|
761 |
+
最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。
|
762 |
+
|
763 |
+
### タグ付けの実施
|
764 |
+
|
765 |
+
スクリプトを実行してタグ付けを行います。
|
766 |
+
```
|
767 |
+
python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
|
768 |
+
```
|
769 |
+
|
770 |
+
教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
771 |
+
```
|
772 |
+
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
773 |
+
```
|
774 |
+
|
775 |
+
初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。
|
776 |
+
|
777 |
+

|
778 |
+
|
779 |
+
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
|
780 |
+
|
781 |
+

|
782 |
+
|
783 |
+

|
784 |
+
|
785 |
+
threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
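たとえば `--thresh 0.5` とすると、より確信度の高いタグだけが付与されます(値は一例です)。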
|
786 |
+
|
787 |
+
batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。
|
788 |
+
|
789 |
+
model_dirオプションでモデルの保存先フォルダを指定できます。
|
790 |
+
|
791 |
+
またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。
|
792 |
+
|
793 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
794 |
+
|
795 |
+
## キャプションとタグ情報の前処理
|
796 |
+
|
797 |
+
スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。
|
798 |
+
|
799 |
+
### キャプションの前処理
|
800 |
+
|
801 |
+
キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。
|
802 |
+
|
803 |
+
```
|
804 |
+
python merge_captions_to_metadata.py --full_path <教師データフォルダ>
|
805 |
+
--in_json <読み込むメタデータファイル名> <メタデータファイル名>
|
806 |
+
```
|
807 |
+
|
808 |
+
メタデータファイル名は任意の名前です。
|
809 |
+
教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。
|
810 |
+
|
811 |
+
```
|
812 |
+
python merge_captions_to_metadata.py --full_path train_data meta_cap.json
|
813 |
+
```
|
814 |
+
|
815 |
+
caption_extensionオプションでキャプションの拡張子を指定できます。
|
816 |
+
|
817 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
818 |
+
|
819 |
+
```
|
820 |
+
python merge_captions_to_metadata.py --full_path
|
821 |
+
train_data1 meta_cap1.json
|
822 |
+
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
823 |
+
train_data2 meta_cap2.json
|
824 |
+
```
|
825 |
+
|
826 |
+
in_jsonを省略した場合、書き込み先のメタデータファイルが既にあればそこから読み込み、同じファイルに上書きします。
|
827 |
+
|
828 |
+
__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
|
829 |
+
|
830 |
+
### タグの前処理
|
831 |
+
|
832 |
+
同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。
|
833 |
+
```
|
834 |
+
python merge_dd_tags_to_metadata.py --full_path <教師データフォルダ>
|
835 |
+
--in_json <読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
836 |
+
```
|
837 |
+
|
838 |
+
先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。
|
839 |
+
```
|
840 |
+
python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
|
841 |
+
```
|
842 |
+
|
843 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
844 |
+
|
845 |
+
```
|
846 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
|
847 |
+
train_data1 meta_cap_dd1.json
|
848 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
|
849 |
+
train_data2 meta_cap_dd2.json
|
850 |
+
```
|
851 |
+
|
852 |
+
in_jsonを省略した場合、書き込み先のメタデータファイルが既にあればそこから読み込み、同じファイルに上書きします。
|
853 |
+
|
854 |
+
__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
|
855 |
+
|
856 |
+
### キャプションとタグのクリーニング
|
857 |
+
|
858 |
+
ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。
|
859 |
+
|
860 |
+
※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。
|
861 |
+
|
862 |
+
クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。
|
863 |
+
|
864 |
+
(教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。)
|
865 |
+
|
866 |
+
```
|
867 |
+
python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
868 |
+
```
|
869 |
+
|
870 |
+
--in_jsonは付きませんのでご注意ください。たとえば次のようになります。
|
871 |
+
|
872 |
+
```
|
873 |
+
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
874 |
+
```
|
875 |
+
|
876 |
+
以上でキャプションとタグの前処理は完了です。
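作成したメタデータファイルの中身を確認したい場合は、たとえば次のような簡単なスクリプトで件数やサンプルを表示できます(ファイル名は上の例に合わせた仮のものです。メタデータの細かい構造は実装によります)。

```python
import json

# 上の例で作成した meta_clean.json を読み込んで概要を表示する(ファイル名は一例)
with open("meta_clean.json", "r", encoding="utf-8") as f:
    metadata = json.load(f)

print(f"エントリ数: {len(metadata)}")

# 先頭の1件だけ内容を表示して、キャプションやタグが入っているか確認する
for key, value in list(metadata.items())[:1]:
    print(key)
    print(json.dumps(value, ensure_ascii=False, indent=2))
```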
|
877 |
+
|
878 |
+
## latentsの事前取得
|
879 |
+
|
880 |
+
※ このステップは必須ではありません。省略しても学習時にlatentsを取得しながら学習できます。
|
881 |
+
また学習時に `random_crop` や `color_aug` などを行う場合にはlatentsの事前取得はできません(画像を毎回変えながら学習するため)。事前取得をしない場合、ここまでのメタデータで学習できます。
|
882 |
+
|
883 |
+
あらかじめ画像の潜在表現を取得しディスクに保存しておきます。それにより、学習を高速に進めることができます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。
|
884 |
+
|
885 |
+
作業フォルダで以下のように入力してください。
|
886 |
+
```
|
887 |
+
python prepare_buckets_latents.py --full_path <教師データフォルダ>
|
888 |
+
<読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
889 |
+
<fine tuningするモデル名またはcheckpoint>
|
890 |
+
--batch_size <バッチサイズ>
|
891 |
+
--max_resolution <解像度 幅,高さ>
|
892 |
+
--mixed_precision <精度>
|
893 |
+
```
|
894 |
+
|
895 |
+
モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。
|
896 |
+
|
897 |
+
```
|
898 |
+
python prepare_buckets_latents.py --full_path
|
899 |
+
train_data meta_clean.json meta_lat.json model.ckpt
|
900 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
901 |
+
```
|
902 |
+
|
903 |
+
教師データフォルダにnumpyのnpz形式でlatentsが保存されます。
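保存されたnpzファイルの中身(配列名や形状)を確認したい場合は、たとえば次のようにできます(ファイル名は仮のものです。含まれる配列名はスクリプトのバージョンによって異なる可能性があります)。

```python
import numpy as np

# 教師データフォルダに保存された npz ファイルをひとつ読み込み、含まれる配列の名前と形状を確認する
data = np.load("train_data/sample.npz")
for name in data.files:
    print(name, data[name].shape, data[name].dtype)
```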
|
904 |
+
|
905 |
+
解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。
|
906 |
+
解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。
|
907 |
+
|
908 |
+
--flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二���に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。
|
909 |
+
|
910 |
+
|
911 |
+
(反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fine_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。)
|
912 |
+
|
913 |
+
バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。
|
914 |
+
解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。
|
915 |
+
|
916 |
+
※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。
|
917 |
+
|
918 |
+
以下のようにbucketingの結果が表示されます。
|
919 |
+
|
920 |
+

|
921 |
+
|
922 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
923 |
+
```
|
924 |
+
python prepare_buckets_latents.py --full_path
|
925 |
+
train_data1 meta_clean.json meta_lat1.json model.ckpt
|
926 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
927 |
+
|
928 |
+
python prepare_buckets_latents.py --full_path
|
929 |
+
train_data2 meta_lat1.json meta_lat2.json model.ckpt
|
930 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
931 |
+
|
932 |
+
```
|
933 |
+
読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。
|
934 |
+
|
935 |
+
__※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__
|
936 |
+
|
train_db.py
CHANGED
@@ -15,7 +15,11 @@ import diffusers
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def collate_fn(examples):
|
@@ -33,24 +37,33 @@ def train(args):
|
|
33 |
|
34 |
tokenizer = train_util.load_tokenizer(args)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
|
|
49 |
|
50 |
if args.debug_dataset:
|
51 |
-
train_util.debug_dataset(
|
52 |
return
|
53 |
|
|
|
|
|
|
|
54 |
# acceleratorを準備する
|
55 |
print("prepare accelerator")
|
56 |
|
@@ -91,7 +104,7 @@ def train(args):
|
|
91 |
vae.requires_grad_(False)
|
92 |
vae.eval()
|
93 |
with torch.no_grad():
|
94 |
-
|
95 |
vae.to("cpu")
|
96 |
if torch.cuda.is_available():
|
97 |
torch.cuda.empty_cache()
|
@@ -115,38 +128,18 @@ def train(args):
|
|
115 |
|
116 |
# 学習に必要なクラスを準備する
|
117 |
print("prepare optimizer, data loader etc.")
|
118 |
-
|
119 |
-
# 8-bit Adamを使う
|
120 |
-
if args.use_8bit_adam:
|
121 |
-
try:
|
122 |
-
import bitsandbytes as bnb
|
123 |
-
except ImportError:
|
124 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
125 |
-
print("use 8-bit Adam optimizer")
|
126 |
-
optimizer_class = bnb.optim.AdamW8bit
|
127 |
-
elif args.use_lion_optimizer:
|
128 |
-
try:
|
129 |
-
import lion_pytorch
|
130 |
-
except ImportError:
|
131 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
132 |
-
print("use Lion optimizer")
|
133 |
-
optimizer_class = lion_pytorch.Lion
|
134 |
-
else:
|
135 |
-
optimizer_class = torch.optim.AdamW
|
136 |
-
|
137 |
if train_text_encoder:
|
138 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
139 |
else:
|
140 |
trainable_params = unet.parameters()
|
141 |
|
142 |
-
|
143 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
144 |
|
145 |
# dataloaderを準備する
|
146 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
147 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
148 |
train_dataloader = torch.utils.data.DataLoader(
|
149 |
-
|
150 |
|
151 |
# 学習ステップ数を計算する
|
152 |
if args.max_train_epochs is not None:
|
@@ -156,9 +149,10 @@ def train(args):
|
|
156 |
if args.stop_text_encoder_training is None:
|
157 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
158 |
|
159 |
-
# lr schedulerを用意する
|
160 |
-
lr_scheduler =
|
161 |
-
|
|
|
162 |
|
163 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
164 |
if args.full_fp16:
|
@@ -195,8 +189,8 @@ def train(args):
|
|
195 |
# 学習する
|
196 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
197 |
print("running training / 学習開始")
|
198 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
199 |
-
print(f" num reg images / 正則化画像の数: {
|
200 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
201 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
202 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -217,7 +211,7 @@ def train(args):
|
|
217 |
loss_total = 0.0
|
218 |
for epoch in range(num_train_epochs):
|
219 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
220 |
-
|
221 |
|
222 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
223 |
unet.train()
|
@@ -281,12 +275,12 @@ def train(args):
|
|
281 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
282 |
|
283 |
accelerator.backward(loss)
|
284 |
-
if accelerator.sync_gradients:
|
285 |
if train_text_encoder:
|
286 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
287 |
else:
|
288 |
params_to_clip = unet.parameters()
|
289 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
@@ -297,9 +291,13 @@ def train(args):
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
|
|
|
|
300 |
current_loss = loss.detach().item()
|
301 |
if args.logging_dir is not None:
|
302 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
303 |
accelerator.log(logs, step=global_step)
|
304 |
|
305 |
if epoch == 0:
|
@@ -326,6 +324,8 @@ def train(args):
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
|
|
|
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
@@ -352,6 +352,8 @@ if __name__ == '__main__':
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
355 |
|
356 |
parser.add_argument("--no_token_padding", action="store_true",
|
357 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
+
import library.config_util as config_util
|
19 |
+
from library.config_util import (
|
20 |
+
ConfigSanitizer,
|
21 |
+
BlueprintGenerator,
|
22 |
+
)
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
|
|
37 |
|
38 |
tokenizer = train_util.load_tokenizer(args)
|
39 |
|
40 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
41 |
+
if args.dataset_config is not None:
|
42 |
+
print(f"Load dataset config from {args.dataset_config}")
|
43 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
44 |
+
ignored = ["train_data_dir", "reg_data_dir"]
|
45 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
46 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
47 |
+
else:
|
48 |
+
user_config = {
|
49 |
+
"datasets": [{
|
50 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
51 |
+
}]
|
52 |
+
}
|
53 |
|
54 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
+
if args.no_token_padding:
|
58 |
+
train_dataset_group.disable_token_padding()
|
59 |
|
60 |
if args.debug_dataset:
|
61 |
+
train_util.debug_dataset(train_dataset_group)
|
62 |
return
|
63 |
|
64 |
+
if cache_latents:
|
65 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
+
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
|
|
|
104 |
vae.requires_grad_(False)
|
105 |
vae.eval()
|
106 |
with torch.no_grad():
|
107 |
+
train_dataset_group.cache_latents(vae)
|
108 |
vae.to("cpu")
|
109 |
if torch.cuda.is_available():
|
110 |
torch.cuda.empty_cache()
|
|
|
128 |
|
129 |
# 学習に必要なクラスを準備する
|
130 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if train_text_encoder:
|
132 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
133 |
else:
|
134 |
trainable_params = unet.parameters()
|
135 |
|
136 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
137 |
|
138 |
# dataloaderを準備する
|
139 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
140 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
141 |
train_dataloader = torch.utils.data.DataLoader(
|
142 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
143 |
|
144 |
# 学習ステップ数を計算する
|
145 |
if args.max_train_epochs is not None:
|
|
|
149 |
if args.stop_text_encoder_training is None:
|
150 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
151 |
|
152 |
+
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
153 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
154 |
+
num_training_steps=args.max_train_steps,
|
155 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
156 |
|
157 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
158 |
if args.full_fp16:
|
|
|
189 |
# 学習する
|
190 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
191 |
print("running training / 学習開始")
|
192 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
193 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
194 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
195 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
196 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
211 |
loss_total = 0.0
|
212 |
for epoch in range(num_train_epochs):
|
213 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
214 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
215 |
|
216 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
217 |
unet.train()
|
|
|
275 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
276 |
|
277 |
accelerator.backward(loss)
|
278 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
279 |
if train_text_encoder:
|
280 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
281 |
else:
|
282 |
params_to_clip = unet.parameters()
|
283 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
284 |
|
285 |
optimizer.step()
|
286 |
lr_scheduler.step()
|
|
|
291 |
progress_bar.update(1)
|
292 |
global_step += 1
|
293 |
|
294 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
295 |
+
|
296 |
current_loss = loss.detach().item()
|
297 |
if args.logging_dir is not None:
|
298 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
299 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
300 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
301 |
accelerator.log(logs, step=global_step)
|
302 |
|
303 |
if epoch == 0:
|
|
|
324 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
325 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
326 |
|
327 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
328 |
+
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
355 |
+
train_util.add_optimizer_arguments(parser)
|
356 |
+
config_util.add_config_arguments(parser)
|
357 |
|
358 |
parser.add_argument("--no_token_padding", action="store_true",
|
359 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
train_db_README-ja.md
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DreamBoothのガイドです。
|
2 |
+
|
3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
4 |
+
|
5 |
+
# 概要
|
6 |
+
|
7 |
+
DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。
|
8 |
+
|
9 |
+
具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。
|
10 |
+
|
11 |
+
スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。
|
12 |
+
|
13 |
+
スクリプトの主な機能は以下の通りです。
|
14 |
+
|
15 |
+
- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。
|
16 |
+
- xformersによる省メモリ化。
|
17 |
+
- 512x512だけではなく任意サイズでの学習。
|
18 |
+
- augmentationによる品質の向上。
|
19 |
+
- DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。
|
20 |
+
- Stable Diffusion形式でのモデルの読み書き。
|
21 |
+
- Aspect Ratio Bucketing。
|
22 |
+
- Stable Diffusion v2.0対応。
|
23 |
+
|
24 |
+
# 学習の手順
|
25 |
+
|
26 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
27 |
+
|
28 |
+
## データの準備
|
29 |
+
|
30 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。
|
31 |
+
|
32 |
+
## 学習の実行
|
33 |
+
|
34 |
+
スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。
|
35 |
+
|
36 |
+
```
|
37 |
+
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
38 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
39 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
40 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
41 |
+
--output_name=<学習したモデル出力時のファイル名>
|
42 |
+
--save_model_as=safetensors
|
43 |
+
--prior_loss_weight=1.0
|
44 |
+
--max_train_steps=1600
|
45 |
+
--learning_rate=1e-6
|
46 |
+
--optimizer_type="AdamW8bit"
|
47 |
+
--xformers
|
48 |
+
--mixed_precision="fp16"
|
49 |
+
--cache_latents
|
50 |
+
--gradient_checkpointing
|
51 |
+
```
|
52 |
+
|
53 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
54 |
+
|
55 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
56 |
+
|
57 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
58 |
+
|
59 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
60 |
+
|
61 |
+
`prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。
|
62 |
+
|
63 |
+
学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。
|
64 |
+
|
65 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
66 |
+
|
67 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
68 |
+
|
69 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
70 |
+
|
71 |
+
省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。
|
72 |
+
|
73 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。
|
74 |
+
|
75 |
+
### よく使われるオプションについて
|
76 |
+
|
77 |
+
以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。
|
78 |
+
|
79 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
80 |
+
- clip skipを2以上を前提としたモデルを学習する
|
81 |
+
- 75トークンを超えたキャプションで学習する
|
82 |
+
|
83 |
+
### DreamBoothでのステップ数について
|
84 |
+
|
85 |
+
当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。
|
86 |
+
|
87 |
+
元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。
|
88 |
+
|
89 |
+
(学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。)
|
90 |
+
|
91 |
+
### DreamBoothでのバッチサイズについて
|
92 |
+
|
93 |
+
モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。
|
94 |
+
|
95 |
+
### 学習率について
|
96 |
+
|
97 |
+
Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。
|
98 |
+
|
99 |
+
### 以前の形式のデータセット指定をした場合のコマンドライン
|
100 |
+
|
101 |
+
解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
|
102 |
+
|
103 |
+
```
|
104 |
+
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
105 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
106 |
+
--train_data_dir=<学習用データのディレクトリ>
|
107 |
+
--reg_data_dir=<正則化画像のディレクトリ>
|
108 |
+
--output_dir=<学習したモデルの出力先ディレクトリ>
|
109 |
+
--output_name=<学習したモデル出力時のファイル名>
|
110 |
+
--prior_loss_weight=1.0
|
111 |
+
--resolution=512
|
112 |
+
--train_batch_size=1
|
113 |
+
--learning_rate=1e-6
|
114 |
+
--max_train_steps=1600
|
115 |
+
--use_8bit_adam
|
116 |
+
--xformers
|
117 |
+
--mixed_precision="bf16"
|
118 |
+
--cache_latents
|
119 |
+
--gradient_checkpointing
|
120 |
+
```
|
121 |
+
|
122 |
+
## 学習したモデルで画像生成する
|
123 |
+
|
124 |
+
学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。
|
125 |
+
|
126 |
+
v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。
|
127 |
+
|
128 |
+
v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。
|
129 |
+
|
130 |
+

|
131 |
+
|
132 |
+
各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
|
133 |
+
|
134 |
+
# DreamBooth特有のその他の主なオプション
|
135 |
+
|
136 |
+
すべてのオプションについては別文書を参照してください。
|
137 |
+
|
138 |
+
## Text Encoderの学習を途中から行わない --stop_text_encoder_training
|
139 |
+
|
140 |
+
stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。
|
141 |
+
|
142 |
+
(恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。)
|
143 |
+
|
144 |
+
## Tokenizerのパディングをしない --no_token_padding
|
145 |
+
no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。
|
146 |
+
|
147 |
+
|
148 |
+
<!--
|
149 |
+
bucketing(後述)を利用しかつaugmentation(後述)を使う場合の例は以下のようになります。
|
150 |
+
|
151 |
+
```
|
152 |
+
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
153 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
154 |
+
--train_data_dir=<学習用データのディレクトリ>
|
155 |
+
--reg_data_dir=<正則化画像のディレクトリ>
|
156 |
+
--output_dir=<学習したモデルの出力先ディレクトリ>
|
157 |
+
--resolution=768,512
|
158 |
+
--train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
|
159 |
+
--use_8bit_adam --xformers --mixed_precision="bf16"
|
160 |
+
--save_every_n_epochs=1 --save_state --save_precision="bf16"
|
161 |
+
--logging_dir=logs
|
162 |
+
--enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
|
163 |
+
--color_aug --flip_aug --gradient_checkpointing --seed 42
|
164 |
+
```
|
165 |
+
|
166 |
+
|
167 |
+
-->
|
train_network.py
CHANGED
@@ -1,8 +1,4 @@
|
|
1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
-
from torch.optim import Optimizer
|
3 |
-
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
-
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
@@ -15,92 +11,39 @@ import json
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
-
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
|
21 |
import library.train_util as train_util
|
22 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
26 |
return examples[0]
|
27 |
|
28 |
|
|
|
29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
31 |
|
32 |
if args.network_train_unet_only:
|
33 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
34 |
elif args.network_train_text_encoder_only:
|
35 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
36 |
else:
|
37 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
38 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
39 |
-
|
40 |
-
return logs
|
41 |
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
45 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
46 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
47 |
-
|
48 |
-
|
49 |
-
def get_scheduler_fix(
|
50 |
-
name: Union[str, SchedulerType],
|
51 |
-
optimizer: Optimizer,
|
52 |
-
num_warmup_steps: Optional[int] = None,
|
53 |
-
num_training_steps: Optional[int] = None,
|
54 |
-
num_cycles: int = 1,
|
55 |
-
power: float = 1.0,
|
56 |
-
):
|
57 |
-
"""
|
58 |
-
Unified API to get any scheduler from its name.
|
59 |
-
Args:
|
60 |
-
name (`str` or `SchedulerType`):
|
61 |
-
The name of the scheduler to use.
|
62 |
-
optimizer (`torch.optim.Optimizer`):
|
63 |
-
The optimizer that will be used during training.
|
64 |
-
num_warmup_steps (`int`, *optional*):
|
65 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
66 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
67 |
-
num_training_steps (`int``, *optional*):
|
68 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
69 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
70 |
-
num_cycles (`int`, *optional*):
|
71 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
72 |
-
power (`float`, *optional*, defaults to 1.0):
|
73 |
-
Power factor. See `POLYNOMIAL` scheduler
|
74 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
75 |
-
The index of the last epoch when resuming training.
|
76 |
-
"""
|
77 |
-
name = SchedulerType(name)
|
78 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
79 |
-
if name == SchedulerType.CONSTANT:
|
80 |
-
return schedule_func(optimizer)
|
81 |
-
|
82 |
-
# All other schedulers require `num_warmup_steps`
|
83 |
-
if num_warmup_steps is None:
|
84 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
85 |
-
|
86 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
87 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
88 |
-
|
89 |
-
# All other schedulers require `num_training_steps`
|
90 |
-
if num_training_steps is None:
|
91 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
92 |
-
|
93 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
94 |
-
return schedule_func(
|
95 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
96 |
-
)
|
97 |
-
|
98 |
-
if name == SchedulerType.POLYNOMIAL:
|
99 |
-
return schedule_func(
|
100 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
101 |
-
)
|
102 |
-
|
103 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
104 |
|
105 |
|
106 |
def train(args):
|
@@ -111,6 +54,7 @@ def train(args):
|
|
111 |
|
112 |
cache_latents = args.cache_latents
|
113 |
use_dreambooth_method = args.in_json is None
|
|
|
114 |
|
115 |
if args.seed is not None:
|
116 |
set_seed(args.seed)
|
@@ -118,38 +62,51 @@ def train(args):
|
|
118 |
tokenizer = train_util.load_tokenizer(args)
|
119 |
|
120 |
# データセットを準備する
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
else:
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if args.debug_dataset:
|
144 |
-
train_util.debug_dataset(
|
145 |
return
|
146 |
-
if len(
|
147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# acceleratorを準備する
|
151 |
print("prepare accelerator")
|
152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
153 |
|
154 |
# mixed precisionに対応した型を用意しておき適宜castする
|
155 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
@@ -161,7 +118,7 @@ def train(args):
|
|
161 |
if args.lowram:
|
162 |
text_encoder.to("cuda")
|
163 |
unet.to("cuda")
|
164 |
-
|
165 |
# モデルに xformers とか memory efficient attention を組み込む
|
166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
167 |
|
@@ -171,13 +128,15 @@ def train(args):
|
|
171 |
vae.requires_grad_(False)
|
172 |
vae.eval()
|
173 |
with torch.no_grad():
|
174 |
-
|
175 |
vae.to("cpu")
|
176 |
if torch.cuda.is_available():
|
177 |
torch.cuda.empty_cache()
|
178 |
gc.collect()
|
179 |
|
180 |
# prepare network
|
|
|
|
|
181 |
print("import network module:", args.network_module)
|
182 |
network_module = importlib.import_module(args.network_module)
|
183 |
|
@@ -208,48 +167,25 @@ def train(args):
|
|
208 |
# 学習に必要なクラスを準備する
|
209 |
print("prepare optimizer, data loader etc.")
|
210 |
|
211 |
-
# 8-bit Adamを使う
|
212 |
-
if args.use_8bit_adam:
|
213 |
-
try:
|
214 |
-
import bitsandbytes as bnb
|
215 |
-
except ImportError:
|
216 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
217 |
-
print("use 8-bit Adam optimizer")
|
218 |
-
optimizer_class = bnb.optim.AdamW8bit
|
219 |
-
elif args.use_lion_optimizer:
|
220 |
-
try:
|
221 |
-
import lion_pytorch
|
222 |
-
except ImportError:
|
223 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
224 |
-
print("use Lion optimizer")
|
225 |
-
optimizer_class = lion_pytorch.Lion
|
226 |
-
else:
|
227 |
-
optimizer_class = torch.optim.AdamW
|
228 |
-
|
229 |
-
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
230 |
-
|
231 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
232 |
-
|
233 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
234 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
235 |
|
236 |
# dataloaderを準備する
|
237 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
238 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
239 |
train_dataloader = torch.utils.data.DataLoader(
|
240 |
-
|
241 |
|
242 |
# 学習ステップ数を計算する
|
243 |
if args.max_train_epochs is not None:
|
244 |
-
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
245 |
-
|
|
|
246 |
|
247 |
# lr schedulerを用意する
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
252 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
253 |
|
254 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
255 |
if args.full_fp16:
|
@@ -317,17 +253,21 @@ def train(args):
|
|
317 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
318 |
|
319 |
# 学習する
|
|
|
320 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
331 |
metadata = {
|
332 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
333 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -335,12 +275,10 @@ def train(args):
|
|
335 |
"ss_learning_rate": args.learning_rate,
|
336 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
337 |
"ss_unet_lr": args.unet_lr,
|
338 |
-
"ss_num_train_images":
|
339 |
-
"ss_num_reg_images":
|
340 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
341 |
"ss_num_epochs": num_train_epochs,
|
342 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
343 |
-
"ss_total_batch_size": total_batch_size,
|
344 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
345 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
346 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -352,33 +290,156 @@ def train(args):
|
|
352 |
"ss_mixed_precision": args.mixed_precision,
|
353 |
"ss_full_fp16": bool(args.full_fp16),
|
354 |
"ss_v2": bool(args.v2),
|
355 |
-
"ss_resolution": args.resolution,
|
356 |
"ss_clip_skip": args.clip_skip,
|
357 |
"ss_max_token_length": args.max_token_length,
|
358 |
-
"ss_color_aug": bool(args.color_aug),
|
359 |
-
"ss_flip_aug": bool(args.flip_aug),
|
360 |
-
"ss_random_crop": bool(args.random_crop),
|
361 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
362 |
"ss_cache_latents": bool(args.cache_latents),
|
363 |
-
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
364 |
-
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
365 |
-
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
366 |
"ss_seed": args.seed,
|
367 |
-
"
|
368 |
"ss_noise_offset": args.noise_offset,
|
369 |
-
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
370 |
-
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
371 |
-
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
372 |
-
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
373 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
374 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
375 |
-
"ss_optimizer": optimizer_name
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
}
|
377 |
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
if args.pretrained_model_name_or_path is not None:
|
383 |
sd_model_name = args.pretrained_model_name_or_path
|
384 |
if os.path.exists(sd_model_name):
|
@@ -397,6 +458,13 @@ def train(args):
|
|
397 |
|
398 |
metadata = {k: str(v) for k, v in metadata.items()}
|
399 |
|
|
400 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
401 |
global_step = 0
|
402 |
|
@@ -409,8 +477,9 @@ def train(args):
|
|
409 |
loss_list = []
|
410 |
loss_total = 0.0
|
411 |
for epoch in range(num_train_epochs):
|
412 |
-
|
413 |
-
|
|
|
414 |
|
415 |
metadata["ss_epoch"] = str(epoch+1)
|
416 |
|
@@ -447,7 +516,7 @@ def train(args):
|
|
447 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
448 |
|
449 |
# Predict the noise residual
|
450 |
-
with autocast():
|
451 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
452 |
|
453 |
if args.v_parameterization:
|
@@ -465,9 +534,9 @@ def train(args):
|
|
465 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
466 |
|
467 |
accelerator.backward(loss)
|
468 |
-
if accelerator.sync_gradients:
|
469 |
params_to_clip = network.get_trainable_params()
|
470 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
471 |
|
472 |
optimizer.step()
|
473 |
lr_scheduler.step()
|
@@ -478,6 +547,8 @@ def train(args):
|
|
478 |
progress_bar.update(1)
|
479 |
global_step += 1
|
480 |
|
|
|
|
|
481 |
current_loss = loss.detach().item()
|
482 |
if epoch == 0:
|
483 |
loss_list.append(current_loss)
|
@@ -508,8 +579,9 @@ def train(args):
|
|
508 |
def save_func():
|
509 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
510 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
511 |
print(f"saving checkpoint: {ckpt_file}")
|
512 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
513 |
|
514 |
def remove_old_func(old_epoch_no):
|
515 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -518,15 +590,18 @@ def train(args):
|
|
518 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
519 |
os.remove(old_ckpt_file)
|
520 |
|
521 |
-
|
522 |
-
|
523 |
-
|
|
|
|
|
|
|
524 |
|
525 |
# end of epoch
|
526 |
|
527 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
528 |
|
529 |
-
is_main_process = accelerator.is_main_process
|
530 |
if is_main_process:
|
531 |
network = unwrap_model(network)
|
532 |
|
@@ -545,7 +620,7 @@ def train(args):
|
|
545 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
546 |
|
547 |
print(f"save trained model to {ckpt_file}")
|
548 |
-
network.save_weights(ckpt_file, save_dtype,
|
549 |
print("model saved.")
|
550 |
|
551 |
|
@@ -555,6 +630,8 @@ if __name__ == '__main__':
|
|
555 |
train_util.add_sd_models_arguments(parser)
|
556 |
train_util.add_dataset_arguments(parser, True, True, True)
|
557 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
558 |
|
559 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
560 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -562,10 +639,6 @@ if __name__ == '__main__':
|
|
562 |
|
563 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
564 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
565 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
566 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
567 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
568 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
569 |
|
570 |
parser.add_argument("--network_weights", type=str, default=None,
|
571 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
|
|
|
|
|
|
1 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
2 |
import importlib
|
3 |
import argparse
|
4 |
import gc
|
|
|
11 |
from tqdm import tqdm
|
12 |
import torch
|
13 |
from accelerate.utils import set_seed
|
|
|
14 |
from diffusers import DDPMScheduler
|
15 |
|
16 |
import library.train_util as train_util
|
17 |
+
from library.train_util import (
|
18 |
+
DreamBoothDataset,
|
19 |
+
)
|
20 |
+
import library.config_util as config_util
|
21 |
+
from library.config_util import (
|
22 |
+
ConfigSanitizer,
|
23 |
+
BlueprintGenerator,
|
24 |
+
)
|
25 |
|
26 |
|
27 |
def collate_fn(examples):
|
28 |
return examples[0]
|
29 |
|
30 |
|
31 |
+
# TODO 他のスクリプトと共通化する
|
32 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
33 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
34 |
|
35 |
if args.network_train_unet_only:
|
36 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
37 |
elif args.network_train_text_encoder_only:
|
38 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
39 |
else:
|
40 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
41 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
|
|
|
|
42 |
|
43 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
44 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
45 |
|
46 |
+
return logs
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
def train(args):
|
|
|
54 |
|
55 |
cache_latents = args.cache_latents
|
56 |
use_dreambooth_method = args.in_json is None
|
57 |
+
use_user_config = args.dataset_config is not None
|
58 |
|
59 |
if args.seed is not None:
|
60 |
set_seed(args.seed)
|
|
|
62 |
tokenizer = train_util.load_tokenizer(args)
|
63 |
|
64 |
# データセットを準備する
|
65 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
66 |
+
if use_user_config:
|
67 |
+
print(f"Load dataset config from {args.dataset_config}")
|
68 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
69 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
70 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
71 |
+
print(
|
72 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
73 |
else:
|
74 |
+
if use_dreambooth_method:
|
75 |
+
print("Use DreamBooth method.")
|
76 |
+
user_config = {
|
77 |
+
"datasets": [{
|
78 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
79 |
+
}]
|
80 |
+
}
|
81 |
+
else:
|
82 |
+
print("Train with captions.")
|
83 |
+
user_config = {
|
84 |
+
"datasets": [{
|
85 |
+
"subsets": [{
|
86 |
+
"image_dir": args.train_data_dir,
|
87 |
+
"metadata_file": args.in_json,
|
88 |
+
}]
|
89 |
+
}]
|
90 |
+
}
|
91 |
+
|
92 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
93 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
94 |
|
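The fallback `user_config` built above has the same shape as the file passed via `--dataset_config`. As a rough illustration only (the `toml` package and every path below are assumptions for the sketch, not part of this diff), such a config file could be produced like this:

```python
# Hypothetical sketch: write a dataset config with the same shape as the
# fallback user_config above. The toml package and all paths are placeholders.
import toml

user_config = {
    "datasets": [{
        "subsets": [{
            "image_dir": "train_images",       # placeholder image folder
            "metadata_file": "meta_lat.json",  # placeholder metadata JSON
        }]
    }]
}

with open("dataset_config.toml", "w", encoding="utf-8") as f:
    toml.dump(user_config, f)  # then pass --dataset_config dataset_config.toml
```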
95 |
if args.debug_dataset:
|
96 |
+
train_util.debug_dataset(train_dataset_group)
|
97 |
return
|
98 |
+
if len(train_dataset_group) == 0:
|
99 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
100 |
return
|
101 |
|
102 |
+
if cache_latents:
|
103 |
+
assert train_dataset_group.is_latent_cacheable(
|
104 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
105 |
+
|
106 |
# acceleratorを準備する
|
107 |
print("prepare accelerator")
|
108 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
109 |
+
is_main_process = accelerator.is_main_process
|
110 |
|
111 |
# mixed precisionに対応した型を用意しておき適宜castする
|
112 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
|
|
118 |
if args.lowram:
|
119 |
text_encoder.to("cuda")
|
120 |
unet.to("cuda")
|
121 |
+
|
122 |
# モデルに xformers とか memory efficient attention を組み込む
|
123 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
124 |
|
|
|
128 |
vae.requires_grad_(False)
|
129 |
vae.eval()
|
130 |
with torch.no_grad():
|
131 |
+
train_dataset_group.cache_latents(vae)
|
132 |
vae.to("cpu")
|
133 |
if torch.cuda.is_available():
|
134 |
torch.cuda.empty_cache()
|
135 |
gc.collect()
|
136 |
|
137 |
# prepare network
|
138 |
+
import sys
|
139 |
+
sys.path.append(os.path.dirname(__file__))
|
140 |
print("import network module:", args.network_module)
|
141 |
network_module = importlib.import_module(args.network_module)
|
142 |
|
|
|
167 |
# 学習に必要なクラスを準備する
|
168 |
print("prepare optimizer, data loader etc.")
|
169 |
|
|
|
|
|
|
|
|
170 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
171 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
|
|
172 |
|
173 |
# dataloaderを準備する
|
174 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
175 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
176 |
train_dataloader = torch.utils.data.DataLoader(
|
177 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
178 |
|
179 |
# 学習ステップ数を計算する
|
180 |
if args.max_train_epochs is not None:
|
181 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
|
182 |
+
if is_main_process:
|
183 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
187 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
|
188 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
|
|
|
|
189 |
|
190 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
191 |
if args.full_fp16:
|
|
|
253 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
254 |
|
255 |
# 学習する
|
256 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
257 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
258 |
+
|
259 |
+
if is_main_process:
|
260 |
+
print("running training / 学習開始")
|
261 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
262 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
263 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
264 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
265 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
266 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
267 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
268 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
269 |
+
|
270 |
+
# TODO refactor metadata creation and move to util
|
271 |
metadata = {
|
272 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
273 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
275 |
"ss_learning_rate": args.learning_rate,
|
276 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
277 |
"ss_unet_lr": args.unet_lr,
|
278 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
279 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
280 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
281 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
282 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
283 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
284 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
290 |
"ss_mixed_precision": args.mixed_precision,
|
291 |
"ss_full_fp16": bool(args.full_fp16),
|
292 |
"ss_v2": bool(args.v2),
|
|
|
293 |
"ss_clip_skip": args.clip_skip,
|
294 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
295 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
296 |
"ss_seed": args.seed,
|
297 |
+
"ss_lowram": args.lowram,
|
298 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
299 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
300 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
301 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
302 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
303 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
304 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
305 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
306 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
307 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
308 |
}
|
309 |
|
310 |
+
if use_user_config:
|
311 |
+
# save metadata of multiple datasets
|
312 |
+
# NOTE: pack "ss_datasets" value as json one time
|
313 |
+
# or should also pack nested collections as json?
|
314 |
+
datasets_metadata = []
|
315 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
316 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
317 |
+
|
318 |
+
for dataset in train_dataset_group.datasets:
|
319 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
320 |
+
dataset_metadata = {
|
321 |
+
"is_dreambooth": is_dreambooth_dataset,
|
322 |
+
"batch_size_per_device": dataset.batch_size,
|
323 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
324 |
+
"num_reg_images": dataset.num_reg_images,
|
325 |
+
"resolution": (dataset.width, dataset.height),
|
326 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
327 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
328 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
329 |
+
"tag_frequency": dataset.tag_frequency,
|
330 |
+
"bucket_info": dataset.bucket_info,
|
331 |
+
}
|
332 |
+
|
333 |
+
subsets_metadata = []
|
334 |
+
for subset in dataset.subsets:
|
335 |
+
subset_metadata = {
|
336 |
+
"img_count": subset.img_count,
|
337 |
+
"num_repeats": subset.num_repeats,
|
338 |
+
"color_aug": bool(subset.color_aug),
|
339 |
+
"flip_aug": bool(subset.flip_aug),
|
340 |
+
"random_crop": bool(subset.random_crop),
|
341 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
342 |
+
"keep_tokens": subset.keep_tokens,
|
343 |
+
}
|
344 |
+
|
345 |
+
image_dir_or_metadata_file = None
|
346 |
+
if subset.image_dir:
|
347 |
+
image_dir = os.path.basename(subset.image_dir)
|
348 |
+
subset_metadata["image_dir"] = image_dir
|
349 |
+
image_dir_or_metadata_file = image_dir
|
350 |
+
|
351 |
+
if is_dreambooth_dataset:
|
352 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
353 |
+
subset_metadata["is_reg"] = subset.is_reg
|
354 |
+
if subset.is_reg:
|
355 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
356 |
+
else:
|
357 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
358 |
+
subset_metadata["metadata_file"] = metadata_file
|
359 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
360 |
+
|
361 |
+
subsets_metadata.append(subset_metadata)
|
362 |
+
|
363 |
+
# merge dataset dir: not reg subset only
|
364 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
365 |
+
if image_dir_or_metadata_file is not None:
|
366 |
+
# datasets may have a certain dir multiple times
|
367 |
+
v = image_dir_or_metadata_file
|
368 |
+
i = 2
|
369 |
+
while v in dataset_dirs_info:
|
370 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
371 |
+
i += 1
|
372 |
+
image_dir_or_metadata_file = v
|
373 |
+
|
374 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
375 |
+
"n_repeats": subset.num_repeats,
|
376 |
+
"img_count": subset.img_count
|
377 |
+
}
|
378 |
+
|
379 |
+
dataset_metadata["subsets"] = subsets_metadata
|
380 |
+
datasets_metadata.append(dataset_metadata)
|
381 |
+
|
382 |
+
# merge tag frequency:
|
383 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
384 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
385 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
386 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
387 |
+
if ds_dir_name in tag_frequency:
|
388 |
+
continue
|
389 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
390 |
+
|
391 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
392 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
393 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
394 |
+
else:
|
395 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
396 |
+
assert len(
|
397 |
+
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
398 |
+
|
399 |
+
dataset = train_dataset_group.datasets[0]
|
400 |
+
|
401 |
+
dataset_dirs_info = {}
|
402 |
+
reg_dataset_dirs_info = {}
|
403 |
+
if use_dreambooth_method:
|
404 |
+
for subset in dataset.subsets:
|
405 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
406 |
+
info[os.path.basename(subset.image_dir)] = {
|
407 |
+
"n_repeats": subset.num_repeats,
|
408 |
+
"img_count": subset.img_count
|
409 |
+
}
|
410 |
+
else:
|
411 |
+
for subset in dataset.subsets:
|
412 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
413 |
+
"n_repeats": subset.num_repeats,
|
414 |
+
"img_count": subset.img_count
|
415 |
+
}
|
416 |
+
|
417 |
+
metadata.update({
|
418 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
419 |
+
"ss_total_batch_size": total_batch_size,
|
420 |
+
"ss_resolution": args.resolution,
|
421 |
+
"ss_color_aug": bool(args.color_aug),
|
422 |
+
"ss_flip_aug": bool(args.flip_aug),
|
423 |
+
"ss_random_crop": bool(args.random_crop),
|
424 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
425 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
426 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
427 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
428 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
429 |
+
"ss_keep_tokens": args.keep_tokens,
|
430 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
431 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
432 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
433 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
434 |
+
})
|
435 |
+
|
436 |
+
# add extra args
|
437 |
+
if args.network_args:
|
438 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
439 |
+
# for key, value in net_kwargs.items():
|
440 |
+
# metadata["ss_arg_" + key] = value
|
441 |
+
|
442 |
+
# model name and hash
|
443 |
if args.pretrained_model_name_or_path is not None:
|
444 |
sd_model_name = args.pretrained_model_name_or_path
|
445 |
if os.path.exists(sd_model_name):
|
|
|
458 |
|
459 |
metadata = {k: str(v) for k, v in metadata.items()}
|
460 |
|
461 |
+
# make minimum metadata for filtering
|
462 |
+
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
463 |
+
minimum_metadata = {}
|
464 |
+
for key in minimum_keys:
|
465 |
+
if key in metadata:
|
466 |
+
minimum_metadata[key] = metadata[key]
|
467 |
+
|
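All `ss_*` entries are stringified before saving, and with `--dataset_config` the per-dataset details are packed as JSON under `ss_datasets`. A minimal sketch for reading them back from a trained LoRA file with the `safetensors` library (the file name is a placeholder):

```python
# Minimal sketch: inspect the ss_* metadata stored in a saved LoRA file.
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt") as f:  # placeholder path
    meta = f.metadata() or {}

print(meta.get("ss_optimizer"), meta.get("ss_num_train_images"))

# When --dataset_config was used, dataset info is stored as one JSON string:
if "ss_datasets" in meta:
    for ds in json.loads(meta["ss_datasets"]):
        print(ds["resolution"], ds["enable_bucket"])
```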
468 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
469 |
global_step = 0
|
470 |
|
|
|
477 |
loss_list = []
|
478 |
loss_total = 0.0
|
479 |
for epoch in range(num_train_epochs):
|
480 |
+
if is_main_process:
|
481 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
482 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
483 |
|
484 |
metadata["ss_epoch"] = str(epoch+1)
|
485 |
|
|
|
516 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
517 |
|
518 |
# Predict the noise residual
|
519 |
+
with accelerator.autocast():
|
520 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
521 |
|
522 |
if args.v_parameterization:
|
|
|
534 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
535 |
|
536 |
accelerator.backward(loss)
|
537 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
538 |
params_to_clip = network.get_trainable_params()
|
539 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
540 |
|
541 |
optimizer.step()
|
542 |
lr_scheduler.step()
|
|
|
547 |
progress_bar.update(1)
|
548 |
global_step += 1
|
549 |
|
550 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
551 |
+
|
552 |
current_loss = loss.detach().item()
|
553 |
if epoch == 0:
|
554 |
loss_list.append(current_loss)
|
|
|
579 |
def save_func():
|
580 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
581 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
582 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
583 |
print(f"saving checkpoint: {ckpt_file}")
|
584 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
585 |
|
586 |
def remove_old_func(old_epoch_no):
|
587 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
590 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
591 |
os.remove(old_ckpt_file)
|
592 |
|
593 |
+
if is_main_process:
|
594 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
595 |
+
if saving and args.save_state:
|
596 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
597 |
+
|
598 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
599 |
|
600 |
# end of epoch
|
601 |
|
602 |
metadata["ss_epoch"] = str(num_train_epochs)
|
603 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
604 |
|
|
|
605 |
if is_main_process:
|
606 |
network = unwrap_model(network)
|
607 |
|
|
|
620 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
621 |
|
622 |
print(f"save trained model to {ckpt_file}")
|
623 |
+
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
624 |
print("model saved.")
|
625 |
|
626 |
|
|
|
630 |
train_util.add_sd_models_arguments(parser)
|
631 |
train_util.add_dataset_arguments(parser, True, True, True)
|
632 |
train_util.add_training_arguments(parser, True)
|
633 |
+
train_util.add_optimizer_arguments(parser)
|
634 |
+
config_util.add_config_arguments(parser)
|
635 |
|
636 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
637 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
639 |
|
640 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
641 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
642 |
|
643 |
parser.add_argument("--network_weights", type=str, default=None,
|
644 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
train_network_README-ja.md
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
1 |
+
# LoRAの学習について
|
2 |
+
|
3 |
+
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
|
4 |
+
|
5 |
+
[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
|
6 |
+
|
7 |
+
通常のLoRAは Linear およびカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
8 |
+
|
9 |
+
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
10 |
+
|
11 |
+
8GB VRAMでもぎりぎり動作するようです。
|
12 |
+
|
13 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
14 |
+
|
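As background for the options below, the basic idea of a LoRA module for a Linear layer is a frozen base weight plus a trainable low-rank update scaled by `network_alpha / network_dim`. This is only a generic sketch for illustration, not the exact implementation in `networks/lora.py`:

```python
# Generic LoRA-for-Linear sketch (illustration only; networks/lora.py differs
# in details such as how target modules are found and replaced).
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base                      # frozen original Linear
        self.base.requires_grad_(False)
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)   # update starts as a no-op
        self.scale = alpha / rank             # --network_alpha / --network_dim

    def forward(self, x):
        return self.base(x) + self.lora_up(self.lora_down(x)) * self.scale
```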
15 |
+
## 学習したモデルに関する注意
|
16 |
+
|
17 |
+
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
18 |
+
|
19 |
+
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
20 |
+
|
21 |
+
# 学習の手順
|
22 |
+
|
23 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
24 |
+
|
25 |
+
## データの準備
|
26 |
+
|
27 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。
|
28 |
+
|
29 |
+
|
30 |
+
## 学習の実行
|
31 |
+
|
32 |
+
`train_network.py`を用います。
|
33 |
+
|
34 |
+
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
35 |
+
|
36 |
+
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
37 |
+
|
38 |
+
以下はコマンドラインの例です。
|
39 |
+
|
40 |
+
```
|
41 |
+
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
42 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
43 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
44 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
45 |
+
--output_name=<学習したモデル出力時のファイル名>
|
46 |
+
--save_model_as=safetensors
|
47 |
+
--prior_loss_weight=1.0
|
48 |
+
--max_train_steps=400
|
49 |
+
--learning_rate=1e-4
|
50 |
+
--optimizer_type="AdamW8bit"
|
51 |
+
--xformers
|
52 |
+
--mixed_precision="fp16"
|
53 |
+
--cache_latents
|
54 |
+
--gradient_checkpointing
|
55 |
+
--save_every_n_epochs=1
|
56 |
+
--network_module=networks.lora
|
57 |
+
```
|
58 |
+
|
59 |
+
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
60 |
+
|
61 |
+
その他、以下のオプションが指定できます。
|
62 |
+
|
63 |
+
* `--network_dim`
|
64 |
+
* LoRAのRANKを指定します(``--network_dim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
65 |
+
* `--network_alpha`
|
66 |
+
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
67 |
+
* `--persistent_data_loader_workers`
|
68 |
+
* Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
69 |
+
* `--max_data_loader_n_workers`
|
70 |
+
* データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
|
71 |
+
* `--network_weights`
|
72 |
+
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
73 |
+
* `--network_train_unet_only`
|
74 |
+
* U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
|
75 |
+
* `--network_train_text_encoder_only`
|
76 |
+
* Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
|
77 |
+
* `--unet_lr`
|
78 |
+
* U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
|
79 |
+
* `--text_encoder_lr`
|
80 |
+
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
81 |
+
* `--network_args`
|
82 |
+
* 複数の引数を指定できます。後述します。
|
83 |
+
|
84 |
+
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
85 |
+
|
86 |
+
## LoRA を Conv2d に拡大して適用する
|
87 |
+
|
88 |
+
通常のLoRAは Linear およびカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
89 |
+
|
90 |
+
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
91 |
+
|
92 |
+
```
|
93 |
+
--network_args "conv_dim=1" "conv_alpha=1"
|
94 |
+
```
|
95 |
+
|
96 |
+
以下のように alpha を省略した場合は 1 になります。
|
97 |
+
|
98 |
+
```
|
99 |
+
--network_args "conv_dim=1"
|
100 |
+
```
|
101 |
+
|
102 |
+
## マージスクリプトについて
|
103 |
+
|
104 |
+
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
|
105 |
+
|
106 |
+
### Stable DiffusionのモデルにLoRAのモデルをマージする
|
107 |
+
|
108 |
+
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
|
109 |
+
|
110 |
+
```
|
111 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
112 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
113 |
+
--models ..\lora_train1\last.safetensors --ratios 0.8
|
114 |
+
```
|
115 |
+
|
116 |
+
Stable Diffusion v2.xのモデルで学習し、それにマージする場合は、--v2オプションを指定してください。
|
117 |
+
|
118 |
+
--sd_modelオプションにマージの元となるStable Diffusionのモデルファイルを指定します(.ckptまたは.safetensorsのみ対応で、Diffusersは今のところ対応していません)。
|
119 |
+
|
120 |
+
--save_toオプションにマージ後のモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
|
121 |
+
|
122 |
+
--modelsに学習したLoRAのモデルファイルを指定します。複数指定も可能で、その時は順にマージします。
|
123 |
+
|
124 |
+
--ratiosにそれぞれのモデルの適用率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。例えば過学習に近いような場合は、適用率を下げるとマシになるかもしれません。モデルの数と同じだけ指定してください。
|
125 |
+
|
126 |
+
複数指定時は以下のようになります。
|
127 |
+
|
128 |
+
```
|
129 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
130 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
131 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
|
132 |
+
```
|
133 |
+
|
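Conceptually, what `merge_lora.py` does per layer is add the scaled low-rank product onto the matching base weight. A simplified sketch of that update for a single weight matrix (the tensor names are illustrative, not the actual state-dict keys):

```python
# Simplified merge of one LoRA pair into a base weight (illustrative names).
import torch

def merge_into_weight(w_base: torch.Tensor,    # [out, in] base weight
                      lora_up: torch.Tensor,   # [out, rank]
                      lora_down: torch.Tensor, # [rank, in]
                      alpha: float, ratio: float) -> torch.Tensor:
    rank = lora_down.shape[0]
    scale = alpha / rank            # same alpha/dim scaling as in training
    return w_base + ratio * scale * (lora_up @ lora_down)
```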
134 |
+
### 複数のLoRAのモデルをマージする
|
135 |
+
|
136 |
+
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
137 |
+
|
138 |
+
たとえば以下のようなコマンドラインになります。
|
139 |
+
|
140 |
+
```
|
141 |
+
python networks\merge_lora.py
|
142 |
+
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
143 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
144 |
+
```
|
145 |
+
|
146 |
+
--sd_modelオプションは指定不要です。
|
147 |
+
|
148 |
+
--save_toオプションにマージ後のLoRAモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
|
149 |
+
|
150 |
+
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
|
151 |
+
|
152 |
+
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージする場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
153 |
+
|
154 |
+
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
155 |
+
|
156 |
+
|
157 |
+
### その他のオプション
|
158 |
+
|
159 |
+
* precision
|
160 |
+
* マージ計算時の精度をfloat、fp16、bf16から指定できます。省略時は精度を確保するためfloatになります。メモリ使用量を減らしたい場合はfp16/bf16を指定してください。
|
161 |
+
* save_precision
|
162 |
+
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
163 |
+
|
164 |
+
|
165 |
+
## 複数のrankが異なるLoRAのモデルをマージする
|
166 |
+
|
167 |
+
複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。
|
168 |
+
|
169 |
+
```
|
170 |
+
python networks\svd_merge_lora.py
|
171 |
+
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
172 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
173 |
+
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
174 |
+
```
|
175 |
+
|
176 |
+
`merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。
|
177 |
+
|
178 |
+
- `--new_rank`
|
179 |
+
- 作成するLoRAのrankを指定します。
|
180 |
+
- `--new_conv_rank`
|
181 |
+
- 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。
|
182 |
+
- `--device`
|
183 |
+
- `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。
|
184 |
+
|
185 |
+
## 当リポジトリ内の画像生成スクリプトで生成する
|
186 |
+
|
187 |
+
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
188 |
+
|
189 |
+
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
190 |
+
|
191 |
+
## 二つのモデルの差分からLoRAモデルを作成する
|
192 |
+
|
193 |
+
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
194 |
+
|
195 |
+
二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。
|
196 |
+
|
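Per layer, the idea is to take the weight difference between the two models and keep only its top singular components as the LoRA pair. A minimal sketch, not the script's exact code:

```python
# Per-layer sketch: approximate a weight difference with a rank-`dim` LoRA pair.
import torch

def extract_lora_pair(w_org: torch.Tensor, w_tuned: torch.Tensor, dim: int = 4):
    delta = (w_tuned - w_org).float()            # [out, in] weight difference
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    lora_up = U[:, :dim] * S[:dim]               # fold singular values into "up"
    lora_down = Vh[:dim, :]
    return lora_up, lora_down                    # delta ≈ lora_up @ lora_down
```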
197 |
+
### スクリプトの実行方法
|
198 |
+
|
199 |
+
以下のように指定してください。
|
200 |
+
```
|
201 |
+
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
202 |
+
--model_tuned fine-tuned-model.ckpt
|
203 |
+
--save_to lora-weights.safetensors --dim 4
|
204 |
+
```
|
205 |
+
|
206 |
+
--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。
|
207 |
+
|
208 |
+
--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。
|
209 |
+
|
210 |
+
--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。
|
211 |
+
|
212 |
+
生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。
|
213 |
+
|
214 |
+
Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。
|
215 |
+
|
216 |
+
### その他のオプション
|
217 |
+
|
218 |
+
- `--v2`
|
219 |
+
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
220 |
+
- `--device`
|
221 |
+
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
222 |
+
- `--save_precision`
|
223 |
+
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
224 |
+
- `--conv_dim`
|
225 |
+
- 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。
|
226 |
+
|
227 |
+
## 画像リサイズスクリプト
|
228 |
+
|
229 |
+
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
|
230 |
+
|
231 |
+
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
|
232 |
+
|
233 |
+
### スクリプトの実行方法
|
234 |
+
|
235 |
+
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
|
236 |
+
|
237 |
+
```
|
238 |
+
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
|
239 |
+
--copy_associated_files 元画像フォルダ 変換先フォルダ
|
240 |
+
```
|
241 |
+
|
242 |
+
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
|
243 |
+
|
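The "same area" rule means each image is scaled so that width × height roughly equals the area of the requested resolution. A small sketch of that calculation with placeholder values (the actual script additionally crops so the sides satisfy `--divisible_by`):

```python
# Sketch of the "resize to the same area" calculation (placeholder values).
import math

def target_size(src_w: int, src_h: int, max_resolution: str = "512x512"):
    tw, th = (int(v) for v in max_resolution.split("x"))
    scale = math.sqrt((tw * th) / (src_w * src_h))  # keep the aspect ratio
    return round(src_w * scale), round(src_h * scale)

print(target_size(1200, 800))  # roughly (627, 418); area ≈ 512 * 512
```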
244 |
+
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
|
245 |
+
|
246 |
+
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式(quality=100)で保存されます。
|
247 |
+
|
248 |
+
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
|
249 |
+
|
250 |
+
|
251 |
+
### その他のオプション
|
252 |
+
|
253 |
+
- divisible_by
|
254 |
+
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
|
255 |
+
- interpolation
|
256 |
+
- 縮小時の補間方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
|
257 |
+
|
258 |
+
|
259 |
+
## 追加情報
|
260 |
+
|
261 |
+
### cloneofsimo氏のリポジトリとの違い
|
262 |
+
|
263 |
+
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
264 |
+
|
265 |
+
またモジュール入れ替え機構は全く異なります。
|
266 |
+
|
267 |
+
### 将来拡張について
|
268 |
+
|
269 |
+
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
|
train_network_opt.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
-
from torch.optim import Optimizer
|
3 |
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
-
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
@@ -15,138 +12,49 @@ import json
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
-
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
-
|
21 |
-
#先に
|
22 |
-
#pip install torch_optimizer
|
23 |
-
#が必要
|
24 |
-
try:
|
25 |
-
import torch_optimizer as optim
|
26 |
-
except:
|
27 |
-
print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
|
28 |
-
try:
|
29 |
-
import adastand
|
30 |
-
except:
|
31 |
-
print("※Adastandが使えません")
|
32 |
-
|
33 |
-
from transformers.optimization import Adafactor, AdafactorSchedule
|
34 |
-
print("**********************************")
|
35 |
##### バケット拡張のためのモジュール
|
36 |
import append_module
|
37 |
######
|
38 |
import library.train_util as train_util
|
39 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def collate_fn(examples):
|
43 |
return examples[0]
|
44 |
|
45 |
|
46 |
-
|
|
|
47 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
53 |
else:
|
54 |
last_lrs = lr_scheduler.get_last_lr()
|
55 |
-
|
56 |
-
logs["lr/
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
|
61 |
-
elif len(last_lrs) == 8:
|
62 |
-
logs_names = ["textencoder", "unet_midblock"]
|
63 |
-
for i in range(3):
|
64 |
-
logs_names.append(f"unet_down_blocks_{i}")
|
65 |
-
logs_names.append(f"unet_up_blocks_{i+1}")
|
66 |
-
else:
|
67 |
-
logs_names = []
|
68 |
-
for i in range(12):
|
69 |
-
logs_names.append(f"text_model_encoder_layers_{i}_")
|
70 |
-
logs_names.append("unet_midblock")
|
71 |
-
for i in range(3):
|
72 |
-
logs_names.append(f"unet_down_blocks_{i}")
|
73 |
-
logs_names.append(f"unet_up_blocks_{i+1}")
|
74 |
-
|
75 |
-
for last_lr, logs_name in zip(last_lrs, logs_names):
|
76 |
-
logs[f"lr/{logs_name}"] = float(last_lr)
|
77 |
|
78 |
return logs
|
79 |
|
80 |
|
81 |
-
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
82 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
83 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
84 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
85 |
-
|
86 |
-
|
87 |
-
def get_scheduler_fix(
|
88 |
-
name: Union[str, SchedulerType],
|
89 |
-
optimizer: Optimizer,
|
90 |
-
num_warmup_steps: Optional[int] = None,
|
91 |
-
num_training_steps: Optional[int] = None,
|
92 |
-
num_cycles: float = 1.,
|
93 |
-
power: float = 1.0,
|
94 |
-
):
|
95 |
-
"""
|
96 |
-
Unified API to get any scheduler from its name.
|
97 |
-
Args:
|
98 |
-
name (`str` or `SchedulerType`):
|
99 |
-
The name of the scheduler to use.
|
100 |
-
optimizer (`torch.optim.Optimizer`):
|
101 |
-
The optimizer that will be used during training.
|
102 |
-
num_warmup_steps (`int`, *optional*):
|
103 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
104 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
105 |
-
num_training_steps (`int``, *optional*):
|
106 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
107 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
108 |
-
num_cycles (`int`, *optional*):
|
109 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
110 |
-
power (`float`, *optional*, defaults to 1.0):
|
111 |
-
Power factor. See `POLYNOMIAL` scheduler
|
112 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
113 |
-
The index of the last epoch when resuming training.
|
114 |
-
"""
|
115 |
-
name = SchedulerType(name)
|
116 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
117 |
-
if name == SchedulerType.CONSTANT:
|
118 |
-
return schedule_func(optimizer)
|
119 |
-
|
120 |
-
# All other schedulers require `num_warmup_steps`
|
121 |
-
if num_warmup_steps is None:
|
122 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
123 |
-
|
124 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
125 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
126 |
-
|
127 |
-
# All other schedulers require `num_training_steps`
|
128 |
-
if num_training_steps is None:
|
129 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
130 |
-
|
131 |
-
if name == SchedulerType.COSINE:
|
132 |
-
print(f"{name} num_cycles: {num_cycles}")
|
133 |
-
return schedule_func(
|
134 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
135 |
-
)
|
136 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
137 |
-
print(f"{name} num_cycles: {int(num_cycles)}")
|
138 |
-
return schedule_func(
|
139 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
|
140 |
-
)
|
141 |
-
|
142 |
-
if name == SchedulerType.POLYNOMIAL:
|
143 |
-
return schedule_func(
|
144 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
145 |
-
)
|
146 |
-
|
147 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
148 |
-
|
149 |
-
|
150 |
def train(args):
|
151 |
session_id = random.randint(0, 2**32)
|
152 |
training_started_at = time.time()
|
@@ -155,6 +63,7 @@ def train(args):
|
|
155 |
|
156 |
cache_latents = args.cache_latents
|
157 |
use_dreambooth_method = args.in_json is None
|
|
|
158 |
|
159 |
if args.seed is not None:
|
160 |
set_seed(args.seed)
|
@@ -162,52 +71,72 @@ def train(args):
|
|
162 |
tokenizer = train_util.load_tokenizer(args)
|
163 |
|
164 |
# データセットを準備する
|
165 |
-
if
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
print("Use DreamBooth method.")
|
172 |
-
train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
173 |
-
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
174 |
-
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
175 |
-
args.bucket_reso_steps, args.bucket_no_upscale,
|
176 |
-
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
177 |
-
args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
|
178 |
else:
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
-
train_util.debug_dataset(
|
194 |
return
|
195 |
-
if len(
|
196 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
197 |
return
|
198 |
|
|
|
|
|
|
|
|
|
199 |
# acceleratorを準備する
|
200 |
print("prepare accelerator")
|
201 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
202 |
|
203 |
# mixed precisionに対応した型を用意しておき適宜castする
|
204 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
205 |
|
206 |
# モデルを読み込む
|
207 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
211 |
# モデルに xformers とか memory efficient attention を組み込む
|
212 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
213 |
|
@@ -217,13 +146,15 @@ def train(args):
|
|
217 |
vae.requires_grad_(False)
|
218 |
vae.eval()
|
219 |
with torch.no_grad():
|
220 |
-
|
221 |
vae.to("cpu")
|
222 |
if torch.cuda.is_available():
|
223 |
torch.cuda.empty_cache()
|
224 |
gc.collect()
|
225 |
|
226 |
# prepare network
|
|
|
|
|
227 |
print("import network module:", args.network_module)
|
228 |
network_module = importlib.import_module(args.network_module)
|
229 |
|
@@ -253,188 +184,65 @@ def train(args):
|
|
253 |
|
254 |
# 学習に必要なクラスを準備する
|
255 |
print("prepare optimizer, data loader etc.")
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
except:
|
260 |
-
not_torch_optimizer_flag = True
|
261 |
-
try:
|
262 |
-
print(f"adastand version is {adastand.__version__()}")
|
263 |
-
not_adasatand_optimzier_flag = False
|
264 |
-
except:
|
265 |
-
not_adasatand_optimzier_flag = True
|
266 |
-
|
267 |
-
# 8-bit Adamを使う
|
268 |
-
if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
|
269 |
-
not_torch_optimizer_flag = False
|
270 |
-
if args.optimizer=="Adafactor":
|
271 |
-
not_adasatand_optimzier_flag = False
|
272 |
-
if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
|
273 |
-
print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
|
274 |
-
args.optimizer="AdamW"
|
275 |
-
if args.use_8bit_adam:
|
276 |
-
if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
|
277 |
-
print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
|
278 |
-
args.use_8bit_adam=False
|
279 |
-
if args.use_8bit_adam:
|
280 |
-
try:
|
281 |
-
import bitsandbytes as bnb
|
282 |
-
except ImportError:
|
283 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
284 |
-
print("use 8-bit Adam optimizer")
|
285 |
-
args.training_comment=f"{args.training_comment} use_8bit_adam=True"
|
286 |
-
if args.optimizer=="Lamb":
|
287 |
-
optimizer_class = bnb.optim.LAMB8bit
|
288 |
-
else:
|
289 |
-
args.optimizer="AdamW"
|
290 |
-
optimizer_class = bnb.optim.AdamW8bit
|
291 |
-
else:
|
292 |
-
print(f"use {args.optimizer}")
|
293 |
-
if args.optimizer=="RAdam":
|
294 |
-
optimizer_class = torch.optim.RAdam
|
295 |
-
elif args.optimizer=="AdaBound":
|
296 |
-
optimizer_class = optim.AdaBound
|
297 |
-
elif args.optimizer=="AdaBelief":
|
298 |
-
optimizer_class = optim.AdaBelief
|
299 |
-
elif args.optimizer=="AdamP":
|
300 |
-
optimizer_class = optim.AdamP
|
301 |
-
elif args.optimizer=="Adafactor":
|
302 |
-
optimizer_class = Adafactor
|
303 |
-
elif args.optimizer=="Adastand":
|
304 |
-
optimizer_class = adastand.Adastand
|
305 |
-
elif args.optimizer=="Adastand_belief":
|
306 |
-
optimizer_class = adastand.Adastand_b
|
307 |
-
elif args.optimizer=="AggMo":
|
308 |
-
optimizer_class = optim.AggMo
|
309 |
-
elif args.optimizer=="Apollo":
|
310 |
-
optimizer_class = optim.Apollo
|
311 |
-
elif args.optimizer=="Lamb":
|
312 |
-
optimizer_class = optim.Lamb
|
313 |
-
elif args.optimizer=="Ranger":
|
314 |
-
optimizer_class = optim.Ranger
|
315 |
-
elif args.optimizer=="RangerVA":
|
316 |
-
optimizer_class = optim.RangerVA
|
317 |
-
elif args.optimizer=="Yogi":
|
318 |
-
optimizer_class = optim.Yogi
|
319 |
-
elif args.optimizer=="Shampoo":
|
320 |
-
optimizer_class = optim.Shampoo
|
321 |
-
elif args.optimizer=="NovoGrad":
|
322 |
-
optimizer_class = optim.NovoGrad
|
323 |
-
elif args.optimizer=="QHAdam":
|
324 |
-
optimizer_class = optim.QHAdam
|
325 |
-
elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
|
326 |
-
optimizer_class = optim.DiffGrad
|
327 |
-
elif args.optimizer=="MADGRAD":
|
328 |
-
optimizer_class = optim.MADGRAD
|
329 |
-
else:
|
330 |
-
optimizer_class = torch.optim.AdamW
|
331 |
-
|
332 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
333 |
-
#optimizerデフォ設定
|
334 |
-
if args.optimizer_arg==None:
|
335 |
-
if args.optimizer=="AdaBelief":
|
336 |
-
args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
|
337 |
-
elif args.optimizer=="DiffGrad":
|
338 |
-
args.optimizer_arg = ["eps=1e-16"]
|
339 |
-
optimizer_arg = {}
|
340 |
-
lookahed_arg = {"k": 5, "alpha": 0.5}
|
341 |
-
adafactor_scheduler_arg = {"initial_lr": 0.}
|
342 |
-
int_args = ["k","n_sma_threshold","warmup"]
|
343 |
-
str_args = ["transformer","grad_transformer"]
|
344 |
-
if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
|
345 |
-
for _opt_arg in args.optimizer_arg:
|
346 |
-
key, value = _opt_arg.split("=")
|
347 |
-
if value=="True" or value=="False":
|
348 |
-
optimizer_arg[key]=bool((value=="True"))
|
349 |
-
elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
|
350 |
-
_value = value.split(",")
|
351 |
-
optimizer_arg[key] = (float(_value[0]),float(_value[1]))
|
352 |
-
del _value
|
353 |
-
elif key in int_args:
|
354 |
-
if "Lookahead" in args.optimizer:
|
355 |
-
lookahed_arg[key] = int(value)
|
356 |
-
else:
|
357 |
-
optimizer_arg[key] = int(value)
|
358 |
-
elif key in str_args:
|
359 |
-
optimizer_arg[key] = value
|
360 |
-
else:
|
361 |
-
if key=="alpha" and "Lookahead" in args.optimizer:
|
362 |
-
lookahed_arg[key] = int(value)
|
363 |
-
elif key=="initial_lr" and args.optimizer == "Adafactor":
|
364 |
-
adafactor_scheduler_arg[key] = float(value)
|
365 |
-
else:
|
366 |
-
optimizer_arg[key] = float(value)
|
367 |
-
del _opt_arg
|
368 |
-
AdafactorScheduler_Flag = False
|
369 |
-
list_of_init_lr = []
|
370 |
-
if args.optimizer=="Adafactor":
|
371 |
-
if not "relative_step" in optimizer_arg:
|
372 |
-
optimizer_arg["relative_step"] = True
|
373 |
-
if "warmup_init" in optimizer_arg:
|
374 |
-
if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
|
375 |
-
print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
|
376 |
-
optimizer_arg["relative_step"] = True
|
377 |
-
if optimizer_arg["relative_step"] == True:
|
378 |
-
AdafactorScheduler_Flag = True
|
379 |
-
list_of_init_lr = [0.,0.]
|
380 |
-
if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
|
381 |
-
if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
|
382 |
-
#if not "initial_lr" in adafactor_scheduler_arg:
|
383 |
-
# adafactor_scheduler_arg = args.learning_rate
|
384 |
-
args.learning_rate = None
|
385 |
-
args.text_encoder_lr = None
|
386 |
-
args.unet_lr = None
|
387 |
-
print(f"optimizer arg: {optimizer_arg}")
|
388 |
-
print("=-----------------------------------=")
|
389 |
-
if not AdafactorScheduler_Flag: args.split_lora_networks = False
|
390 |
if args.split_lora_networks:
|
|
|
391 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
392 |
-
append_module.replace_prepare_optimizer_params(network)
|
393 |
-
trainable_params,
|
394 |
else:
|
395 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
if args.
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
# dataloaderを準備する
|
411 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
412 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
413 |
train_dataloader = torch.utils.data.DataLoader(
|
414 |
-
|
415 |
|
416 |
# 学習ステップ数を計算する
|
417 |
if args.max_train_epochs is not None:
|
418 |
-
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
419 |
-
|
|
|
420 |
|
421 |
# lr schedulerを用意する
|
422 |
-
|
423 |
-
|
424 |
-
print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
|
425 |
-
lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
|
426 |
-
print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
|
427 |
-
del list_of_init_lr
|
428 |
else:
|
429 |
-
lr_scheduler = get_scheduler_fix(
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
#追加機能の設定をコメントに追記して残す
|
435 |
-
|
436 |
-
|
437 |
-
|
|
|
438 |
if args.min_resolution:
|
439 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
440 |
|
@@ -504,17 +312,21 @@ def train(args):
|
|
504 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
505 |
|
506 |
# 学習する
|
|
|
507 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
|
|
|
|
|
|
518 |
metadata = {
|
519 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
520 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -522,12 +334,10 @@ def train(args):
|
|
522 |
"ss_learning_rate": args.learning_rate,
|
523 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
524 |
"ss_unet_lr": args.unet_lr,
|
525 |
-
"ss_num_train_images":
|
526 |
-
"ss_num_reg_images":
|
527 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
528 |
"ss_num_epochs": num_train_epochs,
|
529 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
530 |
-
"ss_total_batch_size": total_batch_size,
|
531 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
532 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
533 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -539,32 +349,156 @@ def train(args):
|
|
539 |
"ss_mixed_precision": args.mixed_precision,
|
540 |
"ss_full_fp16": bool(args.full_fp16),
|
541 |
"ss_v2": bool(args.v2),
|
542 |
-
"ss_resolution": args.resolution,
|
543 |
"ss_clip_skip": args.clip_skip,
|
544 |
"ss_max_token_length": args.max_token_length,
|
545 |
-
"ss_color_aug": bool(args.color_aug),
|
546 |
-
"ss_flip_aug": bool(args.flip_aug),
|
547 |
-
"ss_random_crop": bool(args.random_crop),
|
548 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
549 |
"ss_cache_latents": bool(args.cache_latents),
|
550 |
-
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
551 |
-
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
552 |
-
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
553 |
"ss_seed": args.seed,
|
554 |
-
"
|
555 |
"ss_noise_offset": args.noise_offset,
|
556 |
-
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
557 |
-
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
558 |
-
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
559 |
-
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
560 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
561 |
-
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
}
|
563 |
|
564 |
-
|
|
|
|
|
|
|
|
|
|
|
565 |
# for key, value in net_kwargs.items():
|
566 |
# metadata["ss_arg_" + key] = value
|
567 |
|
|
|
568 |
if args.pretrained_model_name_or_path is not None:
|
569 |
sd_model_name = args.pretrained_model_name_or_path
|
570 |
if os.path.exists(sd_model_name):
|
@@ -583,6 +517,13 @@ def train(args):
|
|
583 |
|
584 |
metadata = {k: str(v) for k, v in metadata.items()}
|
585 |
|
586 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
587 |
global_step = 0
|
588 |
|
@@ -595,8 +536,9 @@ def train(args):
|
|
595 |
loss_list = []
|
596 |
loss_total = 0.0
|
597 |
for epoch in range(num_train_epochs):
|
598 |
-
|
599 |
-
|
|
|
600 |
|
601 |
metadata["ss_epoch"] = str(epoch+1)
|
602 |
|
@@ -633,7 +575,7 @@ def train(args):
|
|
633 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
634 |
|
635 |
# Predict the noise residual
|
636 |
-
with autocast():
|
637 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
638 |
|
639 |
if args.v_parameterization:
|
@@ -651,12 +593,13 @@ def train(args):
|
|
651 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
652 |
|
653 |
accelerator.backward(loss)
|
654 |
-
if accelerator.sync_gradients:
|
655 |
params_to_clip = network.get_trainable_params()
|
656 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
657 |
|
658 |
optimizer.step()
|
659 |
-
|
|
|
660 |
optimizer.zero_grad(set_to_none=True)
|
661 |
|
662 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
@@ -664,6 +607,8 @@ def train(args):
|
|
664 |
progress_bar.update(1)
|
665 |
global_step += 1
|
666 |
|
|
|
|
|
667 |
current_loss = loss.detach().item()
|
668 |
if epoch == 0:
|
669 |
loss_list.append(current_loss)
|
@@ -676,7 +621,7 @@ def train(args):
|
|
676 |
progress_bar.set_postfix(**logs)
|
677 |
|
678 |
if args.logging_dir is not None:
|
679 |
-
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
680 |
accelerator.log(logs, step=global_step)
|
681 |
|
682 |
if global_step >= args.max_train_steps:
|
@@ -694,8 +639,9 @@ def train(args):
|
|
694 |
def save_func():
|
695 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
696 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
697 |
print(f"saving checkpoint: {ckpt_file}")
|
698 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
699 |
|
700 |
def remove_old_func(old_epoch_no):
|
701 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -704,15 +650,18 @@ def train(args):
|
|
704 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
705 |
os.remove(old_ckpt_file)
|
706 |
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
|
711 |
# end of epoch
|
712 |
|
713 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
714 |
|
715 |
-
is_main_process = accelerator.is_main_process
|
716 |
if is_main_process:
|
717 |
network = unwrap_model(network)
|
718 |
|
@@ -731,7 +680,7 @@ def train(args):
|
|
731 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
732 |
|
733 |
print(f"save trained model to {ckpt_file}")
|
734 |
-
network.save_weights(ckpt_file, save_dtype,
|
735 |
print("model saved.")
|
736 |
|
737 |
|
@@ -741,6 +690,8 @@ if __name__ == '__main__':
|
|
741 |
train_util.add_sd_models_arguments(parser)
|
742 |
train_util.add_dataset_arguments(parser, True, True, True)
|
743 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
744 |
|
745 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
746 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -748,10 +699,6 @@ if __name__ == '__main__':
|
|
748 |
|
749 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
750 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
751 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
752 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
753 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
754 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
755 |
|
756 |
parser.add_argument("--network_weights", type=str, default=None,
|
757 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
@@ -771,27 +718,30 @@ if __name__ == '__main__':
|
|
771 |
#Optimizer変更関連のオプション追加
|
772 |
append_module.add_append_arguments(parser)
|
773 |
args = append_module.get_config(parser)
|
774 |
|
775 |
if args.resolution==args.min_resolution:
|
776 |
args.min_resolution=None
|
777 |
|
778 |
train(args)
|
|
|
779 |
|
780 |
-
#学習が終わったら現在のargsを保存する
|
781 |
-
# import yaml
|
782 |
-
# import datetime
|
783 |
-
# _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
784 |
-
# if args.output_name==None:
|
785 |
-
# config_name = f"train_network_config_{_t}.yaml"
|
786 |
-
# else:
|
787 |
-
# config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
788 |
-
# print(f"{config_name} に設定を書き出し中...")
|
789 |
-
# with open(config_name, mode="w") as f:
|
790 |
-
# yaml.dump(args.__dict__, f, indent=4)
|
791 |
-
# print("done!")
|
792 |
|
793 |
'''
|
794 |
optimizer設定メモ
|
|
|
|
|
795 |
(optimizer_argから設定できるように変更するためのメモ)
|
796 |
|
797 |
AdamWのweight_decay初期値は1e-2
|
@@ -821,6 +771,7 @@ Adafactor
|
|
821 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
822 |
huggingfaceのサンプルパラ
|
823 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
|
|
824 |
|
825 |
AggMo
|
826 |
|
|
|
|
|
|
|
1 |
from torch.cuda.amp import autocast
|
2 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
3 |
import importlib
|
4 |
import argparse
|
5 |
import gc
|
|
|
12 |
from tqdm import tqdm
|
13 |
import torch
|
14 |
from accelerate.utils import set_seed
|
15 |
+
#import diffusers
|
16 |
from diffusers import DDPMScheduler
|
17 |
+
|
18 |
##### バケット拡張のためのモジュール
|
19 |
import append_module
|
20 |
######
|
21 |
import library.train_util as train_util
|
22 |
+
from library.train_util import (
|
23 |
+
DreamBoothDataset,
|
24 |
+
)
|
25 |
+
import library.config_util as config_util
|
26 |
+
from library.config_util import (
|
27 |
+
ConfigSanitizer,
|
28 |
+
BlueprintGenerator,
|
29 |
+
)
|
30 |
|
31 |
|
32 |
def collate_fn(examples):
|
33 |
return examples[0]
|
34 |
|
35 |
|
36 |
+
# TODO 他のスクリプトと共通化する
|
37 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
|
38 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
39 |
+
if not args.split_lora_networks:
|
40 |
+
if args.network_train_unet_only:
|
41 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
42 |
+
elif args.network_train_text_encoder_only:
|
43 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
44 |
+
else:
|
45 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
46 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
47 |
else:
|
48 |
last_lrs = lr_scheduler.get_last_lr()
|
49 |
+
for last_lr, t_name in zip(last_lrs, split_names):
|
50 |
+
logs[f"lr/{t_name}"] = float(last_lr)
|
51 |
+
#D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
|
52 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
53 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
54 |
|
55 |
return logs
|
56 |
|
57 |
|
58 |
def train(args):
|
59 |
session_id = random.randint(0, 2**32)
|
60 |
training_started_at = time.time()
|
|
|
63 |
|
64 |
cache_latents = args.cache_latents
|
65 |
use_dreambooth_method = args.in_json is None
|
66 |
+
use_user_config = args.dataset_config is not None
|
67 |
|
68 |
if args.seed is not None:
|
69 |
set_seed(args.seed)
|
|
|
71 |
tokenizer = train_util.load_tokenizer(args)
|
72 |
|
73 |
# データセットを準備する
|
74 |
+
if args.min_resolution:
|
75 |
+
args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
|
76 |
+
if len(args.min_resolution) == 1:
|
77 |
+
args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
|
78 |
+
blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
|
79 |
else:
|
80 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
81 |
+
if use_user_config:
|
82 |
+
print(f"Load dataset config from {args.dataset_config}")
|
83 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
84 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
85 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
86 |
+
print(
|
87 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
88 |
+
else:
|
89 |
+
if use_dreambooth_method:
|
90 |
+
print("Use DreamBooth method.")
|
91 |
+
user_config = {
|
92 |
+
"datasets": [{
|
93 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
94 |
+
}]
|
95 |
+
}
|
96 |
+
else:
|
97 |
+
print("Train with captions.")
|
98 |
+
user_config = {
|
99 |
+
"datasets": [{
|
100 |
+
"subsets": [{
|
101 |
+
"image_dir": args.train_data_dir,
|
102 |
+
"metadata_file": args.in_json,
|
103 |
+
}]
|
104 |
+
}]
|
105 |
+
}
|
106 |
+
|
107 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
108 |
+
if args.min_resolution:
|
109 |
+
train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
110 |
+
else:
|
111 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
112 |
|
113 |
if args.debug_dataset:
|
114 |
+
train_util.debug_dataset(train_dataset_group)
|
115 |
return
|
116 |
+
if len(train_dataset_group) == 0:
|
117 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
118 |
return
|
119 |
|
120 |
+
if cache_latents:
|
121 |
+
assert train_dataset_group.is_latent_cacheable(
|
122 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
123 |
+
|
124 |
# acceleratorを準備する
|
125 |
print("prepare accelerator")
|
126 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
127 |
+
is_main_process = accelerator.is_main_process
|
128 |
|
129 |
# mixed precisionに対応した型を用意しておき適宜castする
|
130 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
131 |
|
132 |
# モデルを読み込む
|
133 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
134 |
+
|
135 |
+
# work on low-ram device
|
136 |
+
if args.lowram:
|
137 |
+
text_encoder.to("cuda")
|
138 |
+
unet.to("cuda")
|
139 |
+
|
140 |
# モデルに xformers とか memory efficient attention を組み込む
|
141 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
142 |
|
|
|
146 |
vae.requires_grad_(False)
|
147 |
vae.eval()
|
148 |
with torch.no_grad():
|
149 |
+
train_dataset_group.cache_latents(vae)
|
150 |
vae.to("cpu")
|
151 |
if torch.cuda.is_available():
|
152 |
torch.cuda.empty_cache()
|
153 |
gc.collect()
|
154 |
|
155 |
# prepare network
|
156 |
+
import sys
|
157 |
+
sys.path.append(os.path.dirname(__file__))
|
158 |
print("import network module:", args.network_module)
|
159 |
network_module = importlib.import_module(args.network_module)
|
160 |
|
|
|
184 |
|
185 |
# 学習に必要なクラスを準備する
|
186 |
print("prepare optimizer, data loader etc.")
|
187 |
+
split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
|
188 |
+
|
189 |
+
used_names = None
|
190 |
if args.split_lora_networks:
|
191 |
+
lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
|
192 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
193 |
+
append_module.replace_prepare_optimizer_params(network, network_module)
|
194 |
+
trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
|
195 |
else:
|
196 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
197 |
+
if split_flag:
|
198 |
+
_t_lr = 0.
|
199 |
+
_u_lr = 0.
|
200 |
+
if args.text_encoder_lr:
|
201 |
+
_t_lr = args.text_encoder_lr
|
202 |
+
if args.unet_lr:
|
203 |
+
_u_lr = args.unet_lr
|
204 |
+
adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
|
205 |
+
|
206 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
207 |
+
if args.use_lookahead:
|
208 |
+
try:
|
209 |
+
import torch_optimizer
|
210 |
+
lookahed_arg = {"k": 5, "alpha": 0.5}
|
211 |
+
if args.lookahead_arg is not None:
|
212 |
+
for _arg in args.lookahead_arg:
|
213 |
+
k, v = _arg.split("=")
|
214 |
+
if k == "k":
|
215 |
+
lookahed_arg[k] = int(v)
|
216 |
+
else:
|
217 |
+
lookahed_arg[k] = float(v)
|
218 |
+
optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
|
219 |
+
except:
|
220 |
+
print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
|
221 |
# dataloaderを準備する
|
222 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
223 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
224 |
train_dataloader = torch.utils.data.DataLoader(
|
225 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
226 |
|
227 |
# 学習ステップ数を計算する
|
228 |
if args.max_train_epochs is not None:
|
229 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
|
230 |
+
if is_main_process:
|
231 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
232 |
|
233 |
# lr schedulerを用意する
|
234 |
+
if args.lr_scheduler.startswith("adafactor") and split_flag:
|
235 |
+
lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
|
236 |
else:
|
237 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
238 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
|
239 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
240 |
+
|
|
|
241 |
#追加機能の設定をコメントに追記して残す
|
242 |
+
if args.use_lookahead:
|
243 |
+
args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
|
244 |
+
if args.split_lora_networks:
|
245 |
+
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
|
246 |
if args.min_resolution:
|
247 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
248 |
|
|
|
312 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
313 |
|
314 |
# 学習する
|
315 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
316 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
317 |
+
|
318 |
+
if is_main_process:
|
319 |
+
print("running training / 学習開始")
|
320 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
321 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
322 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
323 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
324 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
325 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
326 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
327 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
328 |
+
|
329 |
+
# TODO refactor metadata creation and move to util
|
330 |
metadata = {
|
331 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
332 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
334 |
"ss_learning_rate": args.learning_rate,
|
335 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
336 |
"ss_unet_lr": args.unet_lr,
|
337 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
338 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
339 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
340 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
341 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
342 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
343 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
349 |
"ss_mixed_precision": args.mixed_precision,
|
350 |
"ss_full_fp16": bool(args.full_fp16),
|
351 |
"ss_v2": bool(args.v2),
|
|
|
352 |
"ss_clip_skip": args.clip_skip,
|
353 |
"ss_max_token_length": args.max_token_length,
|
|
354 |
"ss_cache_latents": bool(args.cache_latents),
|
355 |
"ss_seed": args.seed,
|
356 |
+
"ss_lowram": args.lowram,
|
357 |
"ss_noise_offset": args.noise_offset,
|
|
358 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
359 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
360 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
361 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
362 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
363 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
364 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
365 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
366 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
367 |
}
|
368 |
|
369 |
+
if use_user_config:
|
370 |
+
# save metadata of multiple datasets
|
371 |
+
# NOTE: pack "ss_datasets" value as json one time
|
372 |
+
# or should also pack nested collections as json?
|
373 |
+
datasets_metadata = []
|
374 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
375 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
376 |
+
|
377 |
+
for dataset in train_dataset_group.datasets:
|
378 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
379 |
+
dataset_metadata = {
|
380 |
+
"is_dreambooth": is_dreambooth_dataset,
|
381 |
+
"batch_size_per_device": dataset.batch_size,
|
382 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
383 |
+
"num_reg_images": dataset.num_reg_images,
|
384 |
+
"resolution": (dataset.width, dataset.height),
|
385 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
386 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
387 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
388 |
+
"tag_frequency": dataset.tag_frequency,
|
389 |
+
"bucket_info": dataset.bucket_info,
|
390 |
+
}
|
391 |
+
|
392 |
+
subsets_metadata = []
|
393 |
+
for subset in dataset.subsets:
|
394 |
+
subset_metadata = {
|
395 |
+
"img_count": subset.img_count,
|
396 |
+
"num_repeats": subset.num_repeats,
|
397 |
+
"color_aug": bool(subset.color_aug),
|
398 |
+
"flip_aug": bool(subset.flip_aug),
|
399 |
+
"random_crop": bool(subset.random_crop),
|
400 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
401 |
+
"keep_tokens": subset.keep_tokens,
|
402 |
+
}
|
403 |
+
|
404 |
+
image_dir_or_metadata_file = None
|
405 |
+
if subset.image_dir:
|
406 |
+
image_dir = os.path.basename(subset.image_dir)
|
407 |
+
subset_metadata["image_dir"] = image_dir
|
408 |
+
image_dir_or_metadata_file = image_dir
|
409 |
+
|
410 |
+
if is_dreambooth_dataset:
|
411 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
412 |
+
subset_metadata["is_reg"] = subset.is_reg
|
413 |
+
if subset.is_reg:
|
414 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
415 |
+
else:
|
416 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
417 |
+
subset_metadata["metadata_file"] = metadata_file
|
418 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
419 |
+
|
420 |
+
subsets_metadata.append(subset_metadata)
|
421 |
+
|
422 |
+
# merge dataset dir: not reg subset only
|
423 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
424 |
+
if image_dir_or_metadata_file is not None:
|
425 |
+
# datasets may have a certain dir multiple times
|
426 |
+
v = image_dir_or_metadata_file
|
427 |
+
i = 2
|
428 |
+
while v in dataset_dirs_info:
|
429 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
430 |
+
i += 1
|
431 |
+
image_dir_or_metadata_file = v
|
432 |
+
|
433 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
434 |
+
"n_repeats": subset.num_repeats,
|
435 |
+
"img_count": subset.img_count
|
436 |
+
}
|
437 |
+
|
438 |
+
dataset_metadata["subsets"] = subsets_metadata
|
439 |
+
datasets_metadata.append(dataset_metadata)
|
440 |
+
|
441 |
+
# merge tag frequency:
|
442 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
443 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
444 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
445 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
446 |
+
if ds_dir_name in tag_frequency:
|
447 |
+
continue
|
448 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
449 |
+
|
450 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
451 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
452 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
453 |
+
else:
|
454 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
455 |
+
assert len(
|
456 |
+
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
457 |
+
|
458 |
+
dataset = train_dataset_group.datasets[0]
|
459 |
+
|
460 |
+
dataset_dirs_info = {}
|
461 |
+
reg_dataset_dirs_info = {}
|
462 |
+
if use_dreambooth_method:
|
463 |
+
for subset in dataset.subsets:
|
464 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
465 |
+
info[os.path.basename(subset.image_dir)] = {
|
466 |
+
"n_repeats": subset.num_repeats,
|
467 |
+
"img_count": subset.img_count
|
468 |
+
}
|
469 |
+
else:
|
470 |
+
for subset in dataset.subsets:
|
471 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
472 |
+
"n_repeats": subset.num_repeats,
|
473 |
+
"img_count": subset.img_count
|
474 |
+
}
|
475 |
+
|
476 |
+
metadata.update({
|
477 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
478 |
+
"ss_total_batch_size": total_batch_size,
|
479 |
+
"ss_resolution": args.resolution,
|
480 |
+
"ss_color_aug": bool(args.color_aug),
|
481 |
+
"ss_flip_aug": bool(args.flip_aug),
|
482 |
+
"ss_random_crop": bool(args.random_crop),
|
483 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
484 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
485 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
486 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
487 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
488 |
+
"ss_keep_tokens": args.keep_tokens,
|
489 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
490 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
491 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
492 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
493 |
+
})
|
494 |
+
|
495 |
+
# add extra args
|
496 |
+
if args.network_args:
|
497 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
498 |
# for key, value in net_kwargs.items():
|
499 |
# metadata["ss_arg_" + key] = value
|
500 |
|
501 |
+
# model name and hash
|
502 |
if args.pretrained_model_name_or_path is not None:
|
503 |
sd_model_name = args.pretrained_model_name_or_path
|
504 |
if os.path.exists(sd_model_name):
|
|
|
517 |
|
518 |
metadata = {k: str(v) for k, v in metadata.items()}
|
519 |
|
520 |
+
# make minimum metadata for filtering
|
521 |
+
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
522 |
+
minimum_metadata = {}
|
523 |
+
for key in minimum_keys:
|
524 |
+
if key in metadata:
|
525 |
+
minimum_metadata[key] = metadata[key]
|
526 |
+
|
527 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
528 |
global_step = 0
|
529 |
|
|
|
536 |
loss_list = []
|
537 |
loss_total = 0.0
|
538 |
for epoch in range(num_train_epochs):
|
539 |
+
if is_main_process:
|
540 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
541 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
542 |
|
543 |
metadata["ss_epoch"] = str(epoch+1)
|
544 |
|
|
|
575 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
576 |
|
577 |
# Predict the noise residual
|
578 |
+
with accelerator.autocast():
|
579 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
580 |
|
581 |
if args.v_parameterization:
|
|
|
593 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
594 |
|
595 |
accelerator.backward(loss)
|
596 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
597 |
params_to_clip = network.get_trainable_params()
|
598 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
599 |
|
600 |
optimizer.step()
|
601 |
+
if accelerator.sync_gradients:
|
602 |
+
lr_scheduler.step()
|
603 |
optimizer.zero_grad(set_to_none=True)
|
604 |
|
605 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
|
607 |
progress_bar.update(1)
|
608 |
global_step += 1
|
609 |
|
610 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
611 |
+
|
612 |
current_loss = loss.detach().item()
|
613 |
if epoch == 0:
|
614 |
loss_list.append(current_loss)
|
|
|
621 |
progress_bar.set_postfix(**logs)
|
622 |
|
623 |
if args.logging_dir is not None:
|
624 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
|
625 |
accelerator.log(logs, step=global_step)
|
626 |
|
627 |
if global_step >= args.max_train_steps:
|
|
|
639 |
def save_func():
|
640 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
641 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
642 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
643 |
print(f"saving checkpoint: {ckpt_file}")
|
644 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
645 |
|
646 |
def remove_old_func(old_epoch_no):
|
647 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
650 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
651 |
os.remove(old_ckpt_file)
|
652 |
|
653 |
+
if is_main_process:
|
654 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
655 |
+
if saving and args.save_state:
|
656 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
657 |
+
|
658 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
659 |
|
660 |
# end of epoch
|
661 |
|
662 |
metadata["ss_epoch"] = str(num_train_epochs)
|
663 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
664 |
|
|
|
665 |
if is_main_process:
|
666 |
network = unwrap_model(network)
|
667 |
|
|
|
680 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
681 |
|
682 |
print(f"save trained model to {ckpt_file}")
|
683 |
+
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
684 |
print("model saved.")
|
685 |
|
686 |
|
|
|
690 |
train_util.add_sd_models_arguments(parser)
|
691 |
train_util.add_dataset_arguments(parser, True, True, True)
|
692 |
train_util.add_training_arguments(parser, True)
|
693 |
+
train_util.add_optimizer_arguments(parser)
|
694 |
+
config_util.add_config_arguments(parser)
|
695 |
|
696 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
697 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
699 |
|
700 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
701 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
702 |
|
703 |
parser.add_argument("--network_weights", type=str, default=None,
|
704 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
718 |
#Optimizer変更関連のオプション追加
|
719 |
append_module.add_append_arguments(parser)
|
720 |
args = append_module.get_config(parser)
|
721 |
+
if not args.not_output_config:
|
722 |
+
#argsを保存する
|
723 |
+
import yaml
|
724 |
+
import datetime
|
725 |
+
_t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
726 |
+
if args.output_name==None:
|
727 |
+
config_name = f"train_network_config_{_t}.yaml"
|
728 |
+
else:
|
729 |
+
config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
730 |
+
print(f"{config_name} に設定を書き出し中...")
|
731 |
+
with open(config_name, mode="w") as f:
|
732 |
+
yaml.dump(args.__dict__, f, indent=4)
|
733 |
|
734 |
if args.resolution==args.min_resolution:
|
735 |
args.min_resolution=None
|
736 |
|
737 |
train(args)
|
738 |
+
print("done!")
|
739 |
|
|
740 |
|
741 |
'''
|
742 |
optimizer設定メモ
|
743 |
+
torch_optimizer.AdaBelief
|
744 |
+
adastand.Adastand
|
745 |
(optimizer_argから設定できるように変更するためのメモ)
|
746 |
|
747 |
AdamWのweight_decay初期値は1e-2
|
|
|
771 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
772 |
huggingfaceのサンプルパラ
|
773 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
774 |
+
epsの二つ目の値1e-3が学習率に影響大きい
|
775 |
|
776 |
AggMo
|
777 |
|
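The memo above quotes Hugging Face's sample Adafactor parameters. As a point of reference only, here is a minimal sketch of how those values map onto the `Adafactor` class in the `transformers` package; in these scripts the optimizer is actually built by `train_util.get_optimizer` from `--optimizer_type` and its companion arguments, so the explicit constructor call below is an illustrative assumption, not the scripts' own wiring.

```python
# Minimal sketch: the sample Adafactor parameters from the memo, passed to
# transformers' Adafactor. The dummy parameter list stands in for the LoRA
# network's trainable parameters.
import torch
from transformers.optimization import Adafactor

params = [torch.nn.Parameter(torch.zeros(4, 4))]  # stand-in for network.prepare_optimizer_params(...)

optimizer = Adafactor(
    params,
    lr=1e-4,                # an explicit lr is required once relative_step=False
    eps=(1e-30, 1e-3),      # the memo notes the second eps value strongly affects the effective lr
    clip_threshold=1.0,
    decay_rate=-0.8,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)
```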
train_textual_inversion.py
CHANGED
@@ -11,7 +11,11 @@ import diffusers
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
-
|
15 |
|
16 |
imagenet_templates_small = [
|
17 |
"a photo of a {}",
|
@@ -79,7 +83,6 @@ def train(args):
|
|
79 |
train_util.prepare_dataset_args(args, True)
|
80 |
|
81 |
cache_latents = args.cache_latents
|
82 |
-
use_dreambooth_method = args.in_json is None
|
83 |
|
84 |
if args.seed is not None:
|
85 |
set_seed(args.seed)
|
@@ -139,21 +142,35 @@ def train(args):
|
|
139 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
140 |
|
141 |
# データセットを準備する
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
157 |
|
158 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
159 |
if use_template:
|
@@ -163,20 +180,30 @@ def train(args):
|
|
163 |
captions = []
|
164 |
for tmpl in templates:
|
165 |
captions.append(tmpl.format(replace_to))
|
166 |
-
|
167 |
-
elif args.num_vectors_per_token > 1:
|
168 |
-
replace_to = " ".join(token_strings)
|
169 |
-
train_dataset.add_replacement(args.token_string, replace_to)
|
170 |
|
171 |
-
|
172 |
|
173 |
if args.debug_dataset:
|
174 |
-
train_util.debug_dataset(
|
175 |
return
|
176 |
-
if len(
|
177 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
178 |
return
|
179 |
|
|
|
|
|
|
|
180 |
# モデルに xformers とか memory efficient attention を組み込む
|
181 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
182 |
|
@@ -186,7 +213,7 @@ def train(args):
|
|
186 |
vae.requires_grad_(False)
|
187 |
vae.eval()
|
188 |
with torch.no_grad():
|
189 |
-
|
190 |
vae.to("cpu")
|
191 |
if torch.cuda.is_available():
|
192 |
torch.cuda.empty_cache()
|
@@ -198,35 +225,14 @@ def train(args):
|
|
198 |
|
199 |
# 学習に必要なクラスを準備する
|
200 |
print("prepare optimizer, data loader etc.")
|
201 |
-
|
202 |
-
# 8-bit Adamを使う
|
203 |
-
if args.use_8bit_adam:
|
204 |
-
try:
|
205 |
-
import bitsandbytes as bnb
|
206 |
-
except ImportError:
|
207 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
208 |
-
print("use 8-bit Adam optimizer")
|
209 |
-
optimizer_class = bnb.optim.AdamW8bit
|
210 |
-
elif args.use_lion_optimizer:
|
211 |
-
try:
|
212 |
-
import lion_pytorch
|
213 |
-
except ImportError:
|
214 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
215 |
-
print("use Lion optimizer")
|
216 |
-
optimizer_class = lion_pytorch.Lion
|
217 |
-
else:
|
218 |
-
optimizer_class = torch.optim.AdamW
|
219 |
-
|
220 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
221 |
-
|
222 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
223 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
224 |
|
225 |
# dataloaderを準備する
|
226 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
227 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
228 |
train_dataloader = torch.utils.data.DataLoader(
|
229 |
-
|
230 |
|
231 |
# 学習ステップ数を計算する
|
232 |
if args.max_train_epochs is not None:
|
@@ -234,8 +240,9 @@ def train(args):
|
|
234 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
235 |
|
236 |
# lr schedulerを用意する
|
237 |
-
lr_scheduler =
|
238 |
-
|
|
|
239 |
|
240 |
# acceleratorがなんかよろしくやってくれるらしい
|
241 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
@@ -283,8 +290,8 @@ def train(args):
|
|
283 |
# 学習する
|
284 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
285 |
print("running training / 学習開始")
|
286 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
287 |
-
print(f" num reg images / 正則化画像の数: {
|
288 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
289 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
290 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -303,12 +310,11 @@ def train(args):
|
|
303 |
|
304 |
for epoch in range(num_train_epochs):
|
305 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
306 |
-
|
307 |
|
308 |
text_encoder.train()
|
309 |
|
310 |
loss_total = 0
|
311 |
-
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
312 |
for step, batch in enumerate(train_dataloader):
|
313 |
with accelerator.accumulate(text_encoder):
|
314 |
with torch.no_grad():
|
@@ -357,9 +363,9 @@ def train(args):
|
|
357 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
358 |
|
359 |
accelerator.backward(loss)
|
360 |
-
if accelerator.sync_gradients:
|
361 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
362 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
363 |
|
364 |
optimizer.step()
|
365 |
lr_scheduler.step()
|
@@ -374,9 +380,14 @@ def train(args):
|
|
374 |
progress_bar.update(1)
|
375 |
global_step += 1
|
376 |
|
|
|
|
|
|
|
377 |
current_loss = loss.detach().item()
|
378 |
if args.logging_dir is not None:
|
379 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
380 |
accelerator.log(logs, step=global_step)
|
381 |
|
382 |
loss_total += current_loss
|
@@ -394,8 +405,6 @@ def train(args):
|
|
394 |
accelerator.wait_for_everyone()
|
395 |
|
396 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
397 |
-
# d = updated_embs - bef_epo_embs
|
398 |
-
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
399 |
|
400 |
if args.save_every_n_epochs is not None:
|
401 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
@@ -417,6 +426,9 @@ def train(args):
|
|
417 |
if saving and args.save_state:
|
418 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
419 |
|
|
|
|
|
|
|
420 |
# end of epoch
|
421 |
|
422 |
is_main_process = accelerator.is_main_process
|
@@ -491,6 +503,8 @@ if __name__ == '__main__':
|
|
491 |
train_util.add_sd_models_arguments(parser)
|
492 |
train_util.add_dataset_arguments(parser, True, True, False)
|
493 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
494 |
|
495 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
496 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
+
import library.config_util as config_util
|
15 |
+
from library.config_util import (
|
16 |
+
ConfigSanitizer,
|
17 |
+
BlueprintGenerator,
|
18 |
+
)
|
19 |
|
20 |
imagenet_templates_small = [
|
21 |
"a photo of a {}",
|
|
|
83 |
train_util.prepare_dataset_args(args, True)
|
84 |
|
85 |
cache_latents = args.cache_latents
|
|
|
86 |
|
87 |
if args.seed is not None:
|
88 |
set_seed(args.seed)
|
|
|
142 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
143 |
|
144 |
# データセットを準備する
|
145 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
146 |
+
if args.dataset_config is not None:
|
147 |
+
print(f"Load dataset config from {args.dataset_config}")
|
148 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
149 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
150 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
151 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
152 |
else:
|
153 |
+
use_dreambooth_method = args.in_json is None
|
154 |
+
if use_dreambooth_method:
|
155 |
+
print("Use DreamBooth method.")
|
156 |
+
user_config = {
|
157 |
+
"datasets": [{
|
158 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
159 |
+
}]
|
160 |
+
}
|
161 |
+
else:
|
162 |
+
print("Train with captions.")
|
163 |
+
user_config = {
|
164 |
+
"datasets": [{
|
165 |
+
"subsets": [{
|
166 |
+
"image_dir": args.train_data_dir,
|
167 |
+
"metadata_file": args.in_json,
|
168 |
+
}]
|
169 |
+
}]
|
170 |
+
}
|
171 |
+
|
172 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
173 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
174 |
|
175 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
176 |
if use_template:
|
|
|
180 |
captions = []
|
181 |
for tmpl in templates:
|
182 |
captions.append(tmpl.format(replace_to))
|
183 |
+
train_dataset_group.add_replacement("", captions)
|
184 |
|
185 |
+
if args.num_vectors_per_token > 1:
|
186 |
+
prompt_replacement = (args.token_string, replace_to)
|
187 |
+
else:
|
188 |
+
prompt_replacement = None
|
189 |
+
else:
|
190 |
+
if args.num_vectors_per_token > 1:
|
191 |
+
replace_to = " ".join(token_strings)
|
192 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
193 |
+
prompt_replacement = (args.token_string, replace_to)
|
194 |
+
else:
|
195 |
+
prompt_replacement = None
|
196 |
|
197 |
if args.debug_dataset:
|
198 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
199 |
return
|
200 |
+
if len(train_dataset_group) == 0:
|
201 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
202 |
return
|
203 |
|
204 |
+
if cache_latents:
|
205 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
206 |
+
|
207 |
# モデルに xformers とか memory efficient attention を組み込む
|
208 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
209 |
|
|
|
213 |
vae.requires_grad_(False)
|
214 |
vae.eval()
|
215 |
with torch.no_grad():
|
216 |
+
train_dataset_group.cache_latents(vae)
|
217 |
vae.to("cpu")
|
218 |
if torch.cuda.is_available():
|
219 |
torch.cuda.empty_cache()
|
|
|
225 |
|
226 |
# 学習に必要なクラスを準備する
|
227 |
print("prepare optimizer, data loader etc.")
|
228 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
229 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
|
|
230 |
|
231 |
# dataloaderを準備する
|
232 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
233 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
234 |
train_dataloader = torch.utils.data.DataLoader(
|
235 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
236 |
|
237 |
# 学習ステップ数を計算する
|
238 |
if args.max_train_epochs is not None:
|
|
|
240 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
241 |
|
242 |
# lr schedulerを用意する
|
243 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
244 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
245 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
246 |
|
247 |
# acceleratorがなんかよろしくやってくれるらしい
|
248 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
290 |
# 学習する
|
291 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
292 |
print("running training / 学習開始")
|
293 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
294 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
295 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
296 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
297 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
310 |
|
311 |
for epoch in range(num_train_epochs):
|
312 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
313 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
314 |
|
315 |
text_encoder.train()
|
316 |
|
317 |
loss_total = 0
|
|
|
318 |
for step, batch in enumerate(train_dataloader):
|
319 |
with accelerator.accumulate(text_encoder):
|
320 |
with torch.no_grad():
|
|
|
363 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
364 |
|
365 |
accelerator.backward(loss)
|
366 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
367 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
368 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
369 |
|
370 |
optimizer.step()
|
371 |
lr_scheduler.step()
|
|
|
380 |
progress_bar.update(1)
|
381 |
global_step += 1
|
382 |
|
383 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
384 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
385 |
+
|
386 |
current_loss = loss.detach().item()
|
387 |
if args.logging_dir is not None:
|
388 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
389 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
390 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
391 |
accelerator.log(logs, step=global_step)
|
392 |
|
393 |
loss_total += current_loss
|
|
|
405 |
accelerator.wait_for_everyone()
|
406 |
|
407 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
|
|
|
408 |
|
409 |
if args.save_every_n_epochs is not None:
|
410 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
|
|
426 |
if saving and args.save_state:
|
427 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
428 |
|
429 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
430 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
431 |
+
|
432 |
# end of epoch
|
433 |
|
434 |
is_main_process = accelerator.is_main_process
|
|
|
503 |
train_util.add_sd_models_arguments(parser)
|
504 |
train_util.add_dataset_arguments(parser, True, True, False)
|
505 |
train_util.add_training_arguments(parser, True)
|
506 |
+
train_util.add_optimizer_arguments(parser)
|
507 |
+
config_util.add_config_arguments(parser)
|
508 |
|
509 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
510 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
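Before the README below, a short illustration of the multi-vector token handling added to `train_textual_inversion.py` above: when `--num_vectors_per_token` is greater than 1, the single `--token_string` is expanded into several trainable tokens, and `prompt_replacement` carries that mapping into sample-image prompts. The construction of `token_strings` is not shown in this hunk, so the list below is an assumption for illustration.

```python
# Minimal sketch of the replacement that prompt_replacement represents above
# (token_strings construction assumed; the script derives it from
# --token_string and --num_vectors_per_token).
token_string = "mychar4"
num_vectors_per_token = 4

token_strings = [token_string] + [f"{token_string}{i}" for i in range(1, num_vectors_per_token)]
replace_to = " ".join(token_strings)               # "mychar4 mychar41 mychar42 mychar43"
prompt_replacement = (token_string, replace_to)

sample_prompt = "mychar4, 1girl"
print(sample_prompt.replace(*prompt_replacement))  # "mychar4 mychar41 mychar42 mychar43, 1girl"
```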
train_ti_README-ja.md
ADDED
@@ -0,0 +1,105 @@
1 |
+
[Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。
|
2 |
+
|
3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
4 |
+
|
5 |
+
実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
|
6 |
+
|
7 |
+
学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。
|
8 |
+
|
9 |
+
# 学習の手順
|
10 |
+
|
11 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
12 |
+
|
13 |
+
## データの準備
|
14 |
+
|
15 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。
|
16 |
+
|
17 |
+
## 学習の実行
|
18 |
+
|
19 |
+
``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。
|
20 |
+
|
21 |
+
```
|
22 |
+
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
|
23 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
24 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
25 |
+
--output_name=<学習したモデル出力時のファイル名>
|
26 |
+
--save_model_as=safetensors
|
27 |
+
--prior_loss_weight=1.0
|
28 |
+
--max_train_steps=1600
|
29 |
+
--learning_rate=1e-6
|
30 |
+
--optimizer_type="AdamW8bit"
|
31 |
+
--xformers
|
32 |
+
--mixed_precision="fp16"
|
33 |
+
--cache_latents
|
34 |
+
--gradient_checkpointing
|
35 |
+
--token_string=mychar4 --init_word=cute --num_vectors_per_token=4
|
36 |
+
```
|
37 |
+
|
38 |
+
``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。
|
39 |
+
|
40 |
+
プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。
|
41 |
+
|
42 |
+
```
|
43 |
+
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
|
44 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
45 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
46 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
47 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
48 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
49 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
50 |
+
49407, 49407, 49407, 49407, 49407, 49407, 49407]])
|
51 |
+
```
|
52 |
+
|
53 |
+
tokenizerがすでに持っている単語(一般的な単語)は使用できません。
|
54 |
+
|
55 |
+
``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。
|
56 |
+
|
57 |
+
``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。
|
58 |
+
|
59 |
+
以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。
|
60 |
+
|
61 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
62 |
+
|
63 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
64 |
+
|
65 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
66 |
+
|
67 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
68 |
+
|
69 |
+
学習させるステップ数 `max_train_steps` はここでは1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。
|
70 |
+
|
71 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
72 |
+
|
73 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
74 |
+
|
75 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
76 |
+
|
77 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。
|
78 |
+
|
79 |
+
### よく使われるオプションについて
|
80 |
+
|
81 |
+
以下の場合にはオプションに関するドキュメントを参照してください。
|
82 |
+
|
83 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
84 |
+
- clip skipを2以上を前提としたモデルを学習する
|
85 |
+
- 75トークンを超えたキャプションで学習する
|
86 |
+
|
87 |
+
### Textual Inversionでのバッチサイズについて
|
88 |
+
|
89 |
+
モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。
|
90 |
+
|
91 |
+
# Textual Inversionのその他の主なオプション
|
92 |
+
|
93 |
+
すべてのオプションについては別文書を参照してください。
|
94 |
+
|
95 |
+
* `--weights`
|
96 |
+
* 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
|
97 |
+
* `--use_object_template`
|
98 |
+
* キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
|
99 |
+
* `--use_style_template`
|
100 |
+
* キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。
|
101 |
+
|
102 |
+
## 当リポジトリ内の画像生成スクリプトで生成する
|
103 |
+
|
104 |
+
gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。
|
105 |
+
|
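The README above says the `--debug_dataset` output should contain token ids of 49408 and above when the token string is being replaced correctly. A minimal sketch of why, assuming the standard `openai/clip-vit-large-patch14` tokenizer used by SD 1.x (the training script registers its new tokens internally in its own way; this only demonstrates the id range):

```python
# The base CLIP vocabulary occupies ids 0..49407, so tokens added for
# textual inversion start at 49408.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
token_strings = ["mychar4", "mychar41", "mychar42", "mychar43"]

num_added = tokenizer.add_tokens(token_strings)
print(num_added)                                       # 4
print(tokenizer.convert_tokens_to_ids(token_strings))  # [49408, 49409, 49410, 49411]
```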