abc commited on
Commit
3249d87
·
1 Parent(s): 74be2a5

Upload 55 files

Browse files
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. .github/workflows/typos.yml +21 -0
  3. .gitignore +7 -0
  4. LICENSE.md +201 -0
  5. README-ja.md +147 -0
  6. README.md +230 -0
  7. append_module.py +378 -56
  8. bitsandbytes_windows/cextension.py +54 -0
  9. bitsandbytes_windows/libbitsandbytes_cpu.dll +0 -0
  10. bitsandbytes_windows/libbitsandbytes_cuda116.dll +3 -0
  11. bitsandbytes_windows/main.py +166 -0
  12. config_README-ja.md +279 -0
  13. fine_tune.py +50 -45
  14. fine_tune_README_ja.md +140 -0
  15. finetune/blip/blip.py +240 -0
  16. finetune/blip/med.py +955 -0
  17. finetune/blip/med_config.json +22 -0
  18. finetune/blip/vit.py +305 -0
  19. finetune/clean_captions_and_tags.py +184 -0
  20. finetune/hypernetwork_nai.py +96 -0
  21. finetune/make_captions.py +162 -0
  22. finetune/make_captions_by_git.py +145 -0
  23. finetune/merge_captions_to_metadata.py +67 -0
  24. finetune/merge_dd_tags_to_metadata.py +62 -0
  25. finetune/prepare_buckets_latents.py +261 -0
  26. finetune/tag_images_by_wd14_tagger.py +200 -0
  27. gen_img_diffusers.py +234 -55
  28. library/model_util.py +5 -1
  29. library/train_util.py +853 -229
  30. networks/check_lora_weights.py +1 -1
  31. networks/extract_lora_from_models.py +44 -25
  32. networks/lora.py +191 -30
  33. networks/merge_lora.py +11 -5
  34. networks/resize_lora.py +187 -50
  35. networks/svd_merge_lora.py +40 -18
  36. requirements.txt +2 -0
  37. tools/canny.py +24 -0
  38. tools/original_control_net.py +320 -0
  39. train_README-ja.md +936 -0
  40. train_db.py +47 -45
  41. train_db_README-ja.md +167 -0
  42. train_network.py +248 -175
  43. train_network_README-ja.md +269 -0
  44. train_network_opt.py +324 -373
  45. train_textual_inversion.py +72 -58
  46. train_ti_README-ja.md +105 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
.github/workflows/typos.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # yamllint disable rule:line-length
3
+ name: Typos
4
+
5
+ on: # yamllint disable-line rule:truthy
6
+ push:
7
+ pull_request:
8
+ types:
9
+ - opened
10
+ - synchronize
11
+ - reopened
12
+
13
+ jobs:
14
+ build:
15
+ runs-on: ubuntu-latest
16
+
17
+ steps:
18
+ - uses: actions/checkout@v3
19
+
20
+ - name: typos-action
21
+ uses: crate-ci/[email protected]
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ logs
2
+ __pycache__
3
+ wd14_tagger_model
4
+ venv
5
+ *.egg-info
6
+ build
7
+ .vscode
LICENSE.md ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2022] [kohya-ss]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README-ja.md ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## リポジトリについて
2
+ Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
3
+
4
+ [README in English](./README.md) ←更新情報はこちらにあります
5
+
6
+ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
7
+
8
+ 以下のスクリプトがあります。
9
+
10
+ * DreamBooth、U-NetおよびText Encoderの学習をサポート
11
+ * fine-tuning、同上
12
+ * 画像生成
13
+ * モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換)
14
+
15
+ ## 使用法について
16
+
17
+ 当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。
18
+
19
+ * [学習について、共通編](./train_README-ja.md) : データ整備やオプションなど
20
+ * [データセット設定](./config_README-ja.md)
21
+ * [DreamBoothの学習について](./train_db_README-ja.md)
22
+ * [fine-tuningのガイド](./fine_tune_README_ja.md):
23
+ * [LoRAの学習について](./train_network_README-ja.md)
24
+ * [Textual Inversionの学習について](./train_ti_README-ja.md)
25
+ * note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
26
+ * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
27
+
28
+ ## Windowsでの動作に必要なプログラム
29
+
30
+ Python 3.10.6およびGitが必要です。
31
+
32
+ - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
33
+ - git: https://git-scm.com/download/win
34
+
35
+ PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
36
+ (venvに限らずスクリプトの実行が可能になりますので注意してください。)
37
+
38
+ - PowerShellを管理者として開きます。
39
+ - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
40
+ - 管理者のPowerShellを閉じます。
41
+
42
+ ## Windows環境でのインストール
43
+
44
+ 以下の例ではPyTorchは1.12.1/CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください。
45
+
46
+ (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
47
+
48
+ 通常の(管理者ではない)PowerShellを開き以下を順に実行します。
49
+
50
+ ```powershell
51
+ git clone https://github.com/kohya-ss/sd-scripts.git
52
+ cd sd-scripts
53
+
54
+ python -m venv venv
55
+ .\venv\Scripts\activate
56
+
57
+ pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
58
+ pip install --upgrade -r requirements.txt
59
+ pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
60
+
61
+ cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
62
+ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
63
+ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
64
+
65
+ accelerate config
66
+ ```
67
+
68
+ <!--
69
+ pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
70
+ pip install --use-pep517 --upgrade -r requirements.txt
71
+ pip install -U -I --no-deps xformers==0.0.16
72
+ -->
73
+
74
+ コマンドプロンプトでは以下になります。
75
+
76
+
77
+ ```bat
78
+ git clone https://github.com/kohya-ss/sd-scripts.git
79
+ cd sd-scripts
80
+
81
+ python -m venv venv
82
+ .\venv\Scripts\activate
83
+
84
+ pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
85
+ pip install --upgrade -r requirements.txt
86
+ pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
87
+
88
+ copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
89
+ copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
90
+ copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
91
+
92
+ accelerate config
93
+ ```
94
+
95
+ (注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。)
96
+
97
+ accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
98
+
99
+ ※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。
100
+
101
+ ```txt
102
+ - This machine
103
+ - No distributed training
104
+ - NO
105
+ - NO
106
+ - NO
107
+ - all
108
+ - fp16
109
+ ```
110
+
111
+ ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
112
+ ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
113
+
114
+ ### PyTorchとxformersのバージョンについて
115
+
116
+ 他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
117
+
118
+ ## アップグレード
119
+
120
+ 新しいリリースがあった場合、以下のコマンドで更新できます。
121
+
122
+ ```powershell
123
+ cd sd-scripts
124
+ git pull
125
+ .\venv\Scripts\activate
126
+ pip install --use-pep517 --upgrade -r requirements.txt
127
+ ```
128
+
129
+ コマンドが成功すれば新しいバージョンが使用できます。
130
+
131
+ ## 謝意
132
+
133
+ LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
134
+
135
+ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
136
+
137
+ ## ライセンス
138
+
139
+ スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
140
+
141
+ [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
142
+
143
+ [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
144
+
145
+ [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
146
+
147
+
README.md ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This repository contains training, generation and utility scripts for Stable Diffusion.
2
+
3
+ [__Change History__](#change-history) is moved to the bottom of the page.
4
+ 更新履歴は[ページ末尾](#change-history)に移しました。
5
+
6
+ [日本語版README](./README-ja.md)
7
+
8
+ For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
9
+
10
+ This repository contains the scripts for:
11
+
12
+ * DreamBooth training, including U-Net and Text Encoder
13
+ * Fine-tuning (native training), including U-Net and Text Encoder
14
+ * LoRA training
15
+ * Texutl Inversion training
16
+ * Image generation
17
+ * Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
18
+
19
+ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
20
+
21
+ ## About requirements.txt
22
+
23
+ These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
24
+
25
+ The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
26
+
27
+ ## Links to how-to-use documents
28
+
29
+ All documents are in Japanese currently.
30
+
31
+ * [Training guide - common](./train_README-ja.md) : data preparation, options etc...
32
+ * [Dataset config](./config_README-ja.md)
33
+ * [DreamBooth training guide](./train_db_README-ja.md)
34
+ * [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
35
+ * [training LoRA](./train_network_README-ja.md)
36
+ * [training Textual Inversion](./train_ti_README-ja.md)
37
+ * note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
38
+ * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
39
+
40
+ ## Windows Required Dependencies
41
+
42
+ Python 3.10.6 and Git:
43
+
44
+ - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
45
+ - git: https://git-scm.com/download/win
46
+
47
+ Give unrestricted script access to powershell so venv can work:
48
+
49
+ - Open an administrator powershell window
50
+ - Type `Set-ExecutionPolicy Unrestricted` and answer A
51
+ - Close admin powershell window
52
+
53
+ ## Windows Installation
54
+
55
+ Open a regular Powershell terminal and type the following inside:
56
+
57
+ ```powershell
58
+ git clone https://github.com/kohya-ss/sd-scripts.git
59
+ cd sd-scripts
60
+
61
+ python -m venv venv
62
+ .\venv\Scripts\activate
63
+
64
+ pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
65
+ pip install --upgrade -r requirements.txt
66
+ pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
67
+
68
+ cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
69
+ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
70
+ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
71
+
72
+ accelerate config
73
+ ```
74
+
75
+ update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
76
+
77
+ Answers to accelerate config:
78
+
79
+ ```txt
80
+ - This machine
81
+ - No distributed training
82
+ - NO
83
+ - NO
84
+ - NO
85
+ - all
86
+ - fp16
87
+ ```
88
+
89
+ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
90
+ ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
91
+
92
+ (Single GPU with id `0` will be used.)
93
+
94
+ ### about PyTorch and xformers
95
+
96
+ Other versions of PyTorch and xformers seem to have problems with training.
97
+ If there is no other reason, please install the specified version.
98
+
99
+ ## Upgrade
100
+
101
+ When a new release comes out you can upgrade your repo with the following command:
102
+
103
+ ```powershell
104
+ cd sd-scripts
105
+ git pull
106
+ .\venv\Scripts\activate
107
+ pip install --use-pep517 --upgrade -r requirements.txt
108
+ ```
109
+
110
+ Once the commands have completed successfully you should be ready to use the new version.
111
+
112
+ ## Credits
113
+
114
+ The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
115
+
116
+ The LoRA expansion to Conv2d 3x3 was initially released by cloneofsimo and its effectiveness was demonstrated at [LoCon](https://github.com/KohakuBlueleaf/LoCon) by KohakuBlueleaf. Thank you so much KohakuBlueleaf!
117
+
118
+ ## License
119
+
120
+ The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers, cloneofsimo's and LoCon), however portions of the project are available under separate license terms:
121
+
122
+ [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
123
+
124
+ [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
125
+
126
+ [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
127
+
128
+ ## Change History
129
+
130
+ - 11 Mar. 2023, 2023/3/11:
131
+ - Fix `svd_merge_lora.py` causes an error about the device.
132
+ - `svd_merge_lora.py` でデバイス関連のエラーが発生する不具合を修正しました。
133
+ - 10 Mar. 2023, 2023/3/10: release v0.5.1
134
+ - Fix to LoRA modules in the model are same to the previous (before 0.5.0) if Conv2d-3x3 is disabled (no `conv_dim` arg, default).
135
+ - Conv2D with kernel size 1x1 in ResNet modules were accidentally included in v0.5.0.
136
+ - Trained models with v0.5.0 will work with Web UI's built-in LoRA and Additional Networks extension.
137
+ - Fix an issue that dim (rank) of LoRA module is limited to the in/out dimensions of the target Linear/Conv2d (in case of the dim > 320).
138
+ - `resize_lora.py` now have a feature to `dynamic resizing` which means each LoRA module can have different ranks (dims). Thanks to mgz-dev for this great work!
139
+ - The appropriate rank is selected based on the complexity of each module with an algorithm specified in the command line arguments. For details: https://github.com/kohya-ss/sd-scripts/pull/243
140
+ - Multiple GPUs training is finally supported in `train_network.py`. Thanks to ddPn08 to solve this long running issue!
141
+ - Dataset with fine-tuning method (with metadata json) now works without images if `.npz` files exist. Thanks to rvhfxb!
142
+ - `train_network.py` can work if the current directory is not the directory where the script is in. Thanks to mio2333!
143
+ - Fix `extract_lora_from_models.py` and `svd_merge_lora.py` doesn't work with higher rank (>320).
144
+
145
+ - LoRAのConv2d-3x3拡張を行わない場合(`conv_dim` を指定しない場合)、以前(v0.5.0)と同じ構成になるよう修正しました。
146
+ - ResNetのカーネルサイズ1x1のConv2dが誤って対象になっていました。
147
+ - ただv0.5.0で学習したモデルは Additional Networks 拡張、およびWeb UIのLoRA機能で問題なく使えると思われます。
148
+ - LoRAモジュールの dim (rank) が、対象モジュールの次元数以下に制限される不具合を修正しました(320より大きい dim を指定した場合)。
149
+ - `resize_lora.py` に `dynamic resizing` (リサイズ後の各LoRAモジュールが異なるrank (dim) を持てる機能)を追加しました。mgz-dev 氏の貢献に感謝します。
150
+ - 適切なランクがコマンドライン引数で指定したアルゴリズムにより自動的に選択されます。詳細はこちらをご覧ください: https://github.com/kohya-ss/sd-scripts/pull/243
151
+ - `train_network.py` でマルチGPU学習をサポートしました。長年の懸案を解決された ddPn08 氏に感謝します。
152
+ - fine-tuning方式のデータセット(メタデータ.jsonファイルを使うデータセット)で `.npz` が存在するときには画像がなくても動作するようになりました。rvhfxb 氏に感謝します。
153
+ - 他のディレクトリから `train_network.py` を呼び出しても動作するよう変更しました。 mio2333 氏に感謝します。
154
+ - `extract_lora_from_models.py` および `svd_merge_lora.py` が320より大きいrankを指定すると動かない不具合を修正しました。
155
+
156
+ - 9 Mar. 2023, 2023/3/9: release v0.5.0
157
+ - There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while.
158
+ - Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
159
+ - `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1).
160
+ - Same as a current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
161
+ - LoCon will be enhanced in the future. Compatibility for future versions is not guaranteed.
162
+ - Specify `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"`
163
+ - [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in Stable Diffusion web UI.
164
+ - __Stable Diffusion web UI built-in LoRA does not support 'LoRA for Conv2d-3x3' now. Consider carefully whether or not to use it.__
165
+ - Merging/extracting scripts also support LoRA for Conv2d-3x3.
166
+ - Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260
167
+ - Empty caption doesn't cause error now, issue https://github.com/kohya-ss/sd-scripts/issues/258
168
+ - Fix sample generation is crashing in Textual Inversion training when using templates, or if height/width is not divisible by 8.
169
+ - Update documents (Japanese only).
170
+
171
+ - 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
172
+ - 最低限のメタデータ(module name, dim, alpha および network_args)が `--no_metadata` オプション指定時にも記録されます。issue https://github.com/kohya-ss/sd-scripts/issues/254
173
+ - `train_network.py` で LoRAの Conv2d-3x3 拡張に対応しました(カーネルサイズ1x1以外のConv2dにも対象範囲を拡大します)。
174
+ - 現在のバージョンの [LoCon](https://github.com/KohakuBlueleaf/LoCon) と同一の仕様です。__KohakuBlueleaf氏のご支援に深く感謝します。__
175
+ - LoCon が将来的に拡張された場合、それらのバージョンでの互換性は保証できません。
176
+ - `--network_args` オプションを `--network_args "conv_dim=4" "conv_alpha=1"` のように指定してください。
177
+ - Stable Diffusion web UI での使用には [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) のversion 0.5.0 以降が必要です。
178
+ - __Stable Diffusion web UI の LoRA 機能は LoRAの Conv2d-3x3 拡張に対応していないようです。使用するか否か慎重にご検討ください。__
179
+ - マージ、抽出のスクリプトについても LoRA の Conv2d-3x3 拡張に対応しました.
180
+ - サンプル画像生成後にCUDAメモリを解放しVRAM使用量を削減しました。 issue https://github.com/kohya-ss/sd-scripts/issues/260
181
+ - 空のキャプションが使えるようになりました。 issue https://github.com/kohya-ss/sd-scripts/issues/258
182
+ - Textual Inversion 学習でテンプレートを使ったとき、height/width が 8 で割り切れなかったときにサンプル画像生成がクラッシュするのを修正しました。
183
+ - ドキュメント類を更新しました。
184
+
185
+ - Sample image generation:
186
+ A prompt file might look like this, for example
187
+
188
+ ```
189
+ # prompt 1
190
+ masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
191
+
192
+ # prompt 2
193
+ masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
194
+ ```
195
+
196
+ Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
197
+
198
+ * `--n` Negative prompt up to the next option.
199
+ * `--w` Specifies the width of the generated image.
200
+ * `--h` Specifies the height of the generated image.
201
+ * `--d` Specifies the seed of the generated image.
202
+ * `--l` Specifies the CFG scale of the generated image.
203
+ * `--s` Specifies the number of steps in the generation.
204
+
205
+ The prompt weighting such as `( )` and `[ ]` are not working.
206
+
207
+ - サンプル画像生成:
208
+ プロンプトファイルは例えば以下のようになります。
209
+
210
+ ```
211
+ # prompt 1
212
+ masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
213
+
214
+ # prompt 2
215
+ masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
216
+ ```
217
+
218
+ `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
219
+
220
+ * `--n` Negative prompt up to the next option.
221
+ * `--w` Specifies the width of the generated image.
222
+ * `--h` Specifies the height of the generated image.
223
+ * `--d` Specifies the seed of the generated image.
224
+ * `--l` Specifies the CFG scale of the generated image.
225
+ * `--s` Specifies the number of steps in the generation.
226
+
227
+ `( )` や `[ ]` などの重みづけは動作しません。
228
+
229
+ Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
230
+ 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
append_module.py CHANGED
@@ -2,7 +2,19 @@ import argparse
2
  import json
3
  import shutil
4
  import time
5
- from typing import Dict, List, NamedTuple, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
6
  from accelerate import Accelerator
7
  from torch.autograd.function import Function
8
  import glob
@@ -28,6 +40,7 @@ import safetensors.torch
28
 
29
  import library.model_util as model_util
30
  import library.train_util as train_util
 
31
 
32
  #============================================================================================================
33
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
@@ -115,6 +128,124 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
115
  return area_size_resos_list, area_size_list
116
 
117
  #============================================================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  #train_util 内より
119
  #============================================================================================================
120
  class BucketManager_append(train_util.BucketManager):
@@ -179,7 +310,7 @@ class BucketManager_append(train_util.BucketManager):
179
  bucket_size_id_list.append(bucket_size_id + i + 1)
180
  _min_error = 1000.
181
  _min_id = bucket_size_id
182
- for now_size_id in bucket_size_id:
183
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
184
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
185
  ar_error = np.abs(ar_errors).min()
@@ -253,13 +384,13 @@ class BucketManager_append(train_util.BucketManager):
253
  return reso, resized_size, ar_error
254
 
255
  class DreamBoothDataset(train_util.DreamBoothDataset):
256
- def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
257
  print("use append DreamBoothDataset")
258
  self.min_resolution = min_resolution
259
  self.area_step = area_step
260
- super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
261
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
262
- flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
263
  def make_buckets(self):
264
  '''
265
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
@@ -352,40 +483,50 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
352
  self.shuffle_buckets()
353
  self._length = len(self.buckets_indices)
354
 
355
- class FineTuningDataset(train_util.FineTuningDataset):
356
- def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
357
- train_util.glob_images = glob_images
358
- super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
359
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
360
- random_crop, dataset_repeats, debug_dataset)
361
-
362
- def glob_images(directory, base="*", npz_flag=True):
363
- img_paths = []
364
- dots = []
365
- for ext in train_util.IMAGE_EXTENSIONS:
366
- dots.append(ext)
367
- if npz_flag:
368
- dots.append(".npz")
369
- for ext in dots:
370
- if base == '*':
371
- img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
372
- else:
373
- img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
374
- return img_paths
375
-
376
  #============================================================================================================
377
  #networks.lora
378
  #============================================================================================================
379
- from networks.lora import LoRANetwork
380
- def replace_prepare_optimizer_params(networks):
381
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
382
-
383
  def enumerate_params(loras, lora_name=None):
384
  params = []
385
  for lora in loras:
386
  if lora_name is not None:
387
- if lora_name in lora.lora_name:
388
- params.extend(lora.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  else:
390
  params.extend(lora.parameters())
391
  return params
@@ -393,6 +534,7 @@ def replace_prepare_optimizer_params(networks):
393
  self.requires_grad_(True)
394
  all_params = []
395
  ret_scheduler_lr = []
 
396
 
397
  if loranames is not None:
398
  textencoder_names = [None]
@@ -405,37 +547,181 @@ def replace_prepare_optimizer_params(networks):
405
  if self.text_encoder_loras:
406
  for textencoder_name in textencoder_names:
407
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
 
408
  if text_encoder_lr is not None:
409
  param_data['lr'] = text_encoder_lr
410
- if scheduler_lr is not None:
411
- ret_scheduler_lr.append(scheduler_lr[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  all_params.append(param_data)
413
 
414
  if self.unet_loras:
415
  for unet_name in unet_names:
416
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
 
 
417
  if unet_lr is not None:
418
  param_data['lr'] = unet_lr
419
- if scheduler_lr is not None:
420
- ret_scheduler_lr.append(scheduler_lr[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  all_params.append(param_data)
422
 
423
- return all_params, ret_scheduler_lr
424
-
425
- LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
 
 
426
 
427
  #============================================================================================================
428
  #新規追加
429
  #============================================================================================================
430
  def add_append_arguments(parser: argparse.ArgumentParser):
431
  # for train_network_opt.py
432
- parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
433
- parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
 
 
434
  parser.add_argument("--split_lora_networks", action="store_true")
435
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
 
 
436
  parser.add_argument("--min_resolution", type=str, default=None)
437
  parser.add_argument("--area_step", type=int, default=1)
438
  parser.add_argument("--config", type=str, default=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
  def create_split_names(split_flag, split_level):
441
  split_names = None
@@ -446,14 +732,28 @@ def create_split_names(split_flag, split_level):
446
  if split_level==1:
447
  unet_names.append(f"lora_unet_down_blocks_")
448
  unet_names.append(f"lora_unet_up_blocks_")
449
- elif split_level==2 or split_level==0:
450
- if split_level==2:
451
  text_encoder_names = []
452
  for i in range(12):
453
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
454
- for i in range(3):
455
- unet_names.append(f"lora_unet_down_blocks_{i}")
456
- unet_names.append(f"lora_unet_up_blocks_{i+1}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  split_names["text_encoder"] = text_encoder_names
458
  split_names["unet"] = unet_names
459
  return split_names
@@ -465,7 +765,7 @@ def get_config(parser):
465
  import datetime
466
  if os.path.splitext(args.config)[-1] == ".yaml":
467
  args.config = os.path.splitext(args.config)[0]
468
- config_path = f"./{args.config}.yaml"
469
  if os.path.exists(config_path):
470
  print(f"{config_path} から設定を読み込み中...")
471
  margs, rest = parser.parse_known_args()
@@ -486,19 +786,41 @@ def get_config(parser):
486
  args_type_dic[key] = act.type
487
  #データタイプの確認とargsにkeyの内容を代入していく
488
  for key, v in configs.items():
489
- if key in args_dic:
490
- if args_dic[key] is not None:
491
- new_type = type(args_dic[key])
492
- if (not type(v) == new_type) and (not new_type==list):
493
- v = new_type(v)
494
- else:
495
- if v is not None:
496
  if not type(v) == args_type_dic[key]:
497
  v = args_type_dic[key](v)
498
- args_dic[key] = v
499
  #最後にデフォから指定が変わってるものを変更する
500
  for key, v in change_def_dic.items():
501
  args_dic[key] = v
502
  else:
503
  print(f"{config_path} が見つかりませんでした")
504
  return args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import json
3
  import shutil
4
  import time
5
+ from typing import (
6
+ Dict,
7
+ List,
8
+ NamedTuple,
9
+ Optional,
10
+ Sequence,
11
+ Tuple,
12
+ Union,
13
+ )
14
+ from dataclasses import (
15
+ asdict,
16
+ dataclass,
17
+ )
18
  from accelerate import Accelerator
19
  from torch.autograd.function import Function
20
  import glob
 
40
 
41
  import library.model_util as model_util
42
  import library.train_util as train_util
43
+ import library.config_util as config_util
44
 
45
  #============================================================================================================
46
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
 
128
  return area_size_resos_list, area_size_list
129
 
130
  #============================================================================================================
131
+ #config_util 内より
132
+ #============================================================================================================
133
+ @dataclass
134
+ class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
135
+ min_resolution: Optional[Tuple[int, int]] = None
136
+ area_step : int = 2
137
+
138
+ class ConfigSanitizer(config_util.ConfigSanitizer):
139
+ #@config_util.curry
140
+ @staticmethod
141
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
142
+ config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
143
+ return tuple(value)
144
+
145
+ #@config_util.curry
146
+ @staticmethod
147
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
148
+ config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
149
+ try:
150
+ config_util.Schema(klass)(value)
151
+ return (value, value)
152
+ except:
153
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
154
+ # datasets schema
155
+ DATASET_ASCENDABLE_SCHEMA = {
156
+ "batch_size": int,
157
+ "bucket_no_upscale": bool,
158
+ "bucket_reso_steps": int,
159
+ "enable_bucket": bool,
160
+ "max_bucket_reso": int,
161
+ "min_bucket_reso": int,
162
+ "resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
163
+ "min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
164
+ "area_step": int,
165
+ }
166
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
167
+ super().__init__(support_dreambooth, support_finetuning, support_dropout)
168
+ def _check(self):
169
+ print(self.db_dataset_schema)
170
+
171
+ class BlueprintGenerator(config_util.BlueprintGenerator):
172
+ def __init__(self, sanitizer: ConfigSanitizer):
173
+ config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
174
+ super().__init__(sanitizer)
175
+
176
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
177
+ datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
178
+
179
+ for dataset_blueprint in dataset_group_blueprint.datasets:
180
+ if dataset_blueprint.is_dreambooth:
181
+ subset_klass = train_util.DreamBoothSubset
182
+ dataset_klass = DreamBoothDataset
183
+ else:
184
+ subset_klass = train_util.FineTuningSubset
185
+ dataset_klass = train_util.FineTuningDataset
186
+
187
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
188
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
189
+ datasets.append(dataset)
190
+
191
+ # print info
192
+ info = ""
193
+ for i, dataset in enumerate(datasets):
194
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
195
+ info += config_util.dedent(f"""\
196
+ [Dataset {i}]
197
+ batch_size: {dataset.batch_size}
198
+ resolution: {(dataset.width, dataset.height)}
199
+ enable_bucket: {dataset.enable_bucket}
200
+ """)
201
+
202
+ if dataset.enable_bucket:
203
+ info += config_util.indent(config_util.dedent(f"""\
204
+ min_bucket_reso: {dataset.min_bucket_reso}
205
+ max_bucket_reso: {dataset.max_bucket_reso}
206
+ bucket_reso_steps: {dataset.bucket_reso_steps}
207
+ bucket_no_upscale: {dataset.bucket_no_upscale}
208
+ \n"""), " ")
209
+ else:
210
+ info += "\n"
211
+
212
+ for j, subset in enumerate(dataset.subsets):
213
+ info += config_util.indent(config_util.dedent(f"""\
214
+ [Subset {j} of Dataset {i}]
215
+ image_dir: "{subset.image_dir}"
216
+ image_count: {subset.img_count}
217
+ num_repeats: {subset.num_repeats}
218
+ shuffle_caption: {subset.shuffle_caption}
219
+ keep_tokens: {subset.keep_tokens}
220
+ caption_dropout_rate: {subset.caption_dropout_rate}
221
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
222
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
223
+ color_aug: {subset.color_aug}
224
+ flip_aug: {subset.flip_aug}
225
+ face_crop_aug_range: {subset.face_crop_aug_range}
226
+ random_crop: {subset.random_crop}
227
+ """), " ")
228
+
229
+ if is_dreambooth:
230
+ info += config_util.indent(config_util.dedent(f"""\
231
+ is_reg: {subset.is_reg}
232
+ class_tokens: {subset.class_tokens}
233
+ caption_extension: {subset.caption_extension}
234
+ \n"""), " ")
235
+ else:
236
+ info += config_util.indent(config_util.dedent(f"""\
237
+ metadata_file: {subset.metadata_file}
238
+ \n"""), " ")
239
+
240
+ print(info)
241
+
242
+ # make buckets first because it determines the length of dataset
243
+ for i, dataset in enumerate(datasets):
244
+ print(f"[Dataset {i}]")
245
+ dataset.make_buckets()
246
+
247
+ return train_util.DatasetGroup(datasets)
248
+ #============================================================================================================
249
  #train_util 内より
250
  #============================================================================================================
251
  class BucketManager_append(train_util.BucketManager):
 
310
  bucket_size_id_list.append(bucket_size_id + i + 1)
311
  _min_error = 1000.
312
  _min_id = bucket_size_id
313
+ for now_size_id in bucket_size_id_list:
314
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
315
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
316
  ar_error = np.abs(ar_errors).min()
 
384
  return reso, resized_size, ar_error
385
 
386
  class DreamBoothDataset(train_util.DreamBoothDataset):
387
+ def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
388
  print("use append DreamBoothDataset")
389
  self.min_resolution = min_resolution
390
  self.area_step = area_step
391
+ super().__init__(subsets, batch_size, tokenizer, max_token_length,
392
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
393
+ prior_loss_weight, debug_dataset)
394
  def make_buckets(self):
395
  '''
396
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
 
483
  self.shuffle_buckets()
484
  self._length = len(self.buckets_indices)
485
 
486
+ import transformers
487
+ from torch.optim import Optimizer
488
+ from diffusers.optimization import SchedulerType
489
+ from typing import Union
490
+ def get_scheduler_Adafactor(
491
+ name: Union[str, SchedulerType],
492
+ optimizer: Optimizer,
493
+ scheduler_arg: Dict
494
+ ):
495
+ if name.startswith("adafactor"):
496
+ assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
497
+ print(scheduler_arg)
498
+ return AdafactorSchedule_append(optimizer, **scheduler_arg)
 
 
 
 
 
 
 
 
499
  #============================================================================================================
500
  #networks.lora
501
  #============================================================================================================
502
+ #from networks.lora import LoRANetwork
503
+ def replace_prepare_optimizer_params(networks, network_module):
504
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
505
+
506
  def enumerate_params(loras, lora_name=None):
507
  params = []
508
  for lora in loras:
509
  if lora_name is not None:
510
+ get_param_flag = False
511
+ if "attentions" in lora_name or "lora_unet_up_blocks_0_resnets_2":
512
+ lora_names = [lora_name]
513
+ if "attentions" in lora_name:
514
+ lora_names.append(lora_name.replace("attentions", "resnets"))
515
+ elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
516
+ lora_names.append("lora_unet_up_blocks_0_upsamplers_")
517
+ elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
518
+ lora_names.append("lora_unet_up_blocks_1_upsamplers_")
519
+ elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
520
+ lora_names.append("lora_unet_up_blocks_2_upsamplers_")
521
+
522
+ for _name in lora_names:
523
+ if _name in lora.lora_name:
524
+ get_param_flag = True
525
+ break
526
+ else:
527
+ if lora_name in lora.lora_name:
528
+ get_param_flag = True
529
+ if get_param_flag: params.extend(lora.parameters())
530
  else:
531
  params.extend(lora.parameters())
532
  return params
 
534
  self.requires_grad_(True)
535
  all_params = []
536
  ret_scheduler_lr = []
537
+ used_names = []
538
 
539
  if loranames is not None:
540
  textencoder_names = [None]
 
547
  if self.text_encoder_loras:
548
  for textencoder_name in textencoder_names:
549
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
550
+ used_names.append(textencoder_name)
551
  if text_encoder_lr is not None:
552
  param_data['lr'] = text_encoder_lr
553
+ if lr_dic is not None:
554
+ if textencoder_name in lr_dic:
555
+ param_data['lr'] = lr_dic[textencoder_name]
556
+ print(f"{textencoder_name} lr: {param_data['lr']}")
557
+
558
+ if block_args_dic is not None:
559
+ if "lora_te_" in block_args_dic:
560
+ for pname, value in block_args_dic["lora_te_"].items():
561
+ param_data[pname] = value
562
+ if textencoder_name in block_args_dic:
563
+ for pname, value in block_args_dic[textencoder_name].items():
564
+ param_data[pname] = value
565
+
566
+ if text_encoder_lr is not None:
567
+ ret_scheduler_lr.append(text_encoder_lr)
568
+ else:
569
+ ret_scheduler_lr.append(0.)
570
+ if lr_dic is not None:
571
+ if textencoder_name in lr_dic:
572
+ ret_scheduler_lr[-1] = lr_dic[textencoder_name]
573
  all_params.append(param_data)
574
 
575
  if self.unet_loras:
576
  for unet_name in unet_names:
577
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
578
+ if len(param_data["params"])==0: continue
579
+ used_names.append(unet_name)
580
  if unet_lr is not None:
581
  param_data['lr'] = unet_lr
582
+ if lr_dic is not None:
583
+ if unet_name in lr_dic:
584
+ param_data['lr'] = lr_dic[unet_name]
585
+ print(f"{unet_name} lr: {param_data['lr']}")
586
+
587
+ if block_args_dic is not None:
588
+ if "lora_unet_" in block_args_dic:
589
+ for pname, value in block_args_dic["lora_unet_"].items():
590
+ param_data[pname] = value
591
+ if unet_name in block_args_dic:
592
+ for pname, value in block_args_dic[unet_name].items():
593
+ param_data[pname] = value
594
+
595
+ if unet_lr is not None:
596
+ ret_scheduler_lr.append(unet_lr)
597
+ else:
598
+ ret_scheduler_lr.append(0.)
599
+ if lr_dic is not None:
600
+ if unet_name in lr_dic:
601
+ ret_scheduler_lr[-1] = lr_dic[unet_name]
602
  all_params.append(param_data)
603
 
604
+ return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
605
+ try:
606
+ network_module.LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
607
+ except:
608
+ print("cant't replace prepare_optimizer_params")
609
 
610
  #============================================================================================================
611
  #新規追加
612
  #============================================================================================================
613
  def add_append_arguments(parser: argparse.ArgumentParser):
614
  # for train_network_opt.py
615
+ #parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
616
+ #parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
617
+ parser.add_argument("--use_lookahead", action="store_true")
618
+ parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
619
  parser.add_argument("--split_lora_networks", action="store_true")
620
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
621
+ parser.add_argument("--blocks_lr_setting", type=str, default=None)
622
+ parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
623
  parser.add_argument("--min_resolution", type=str, default=None)
624
  parser.add_argument("--area_step", type=int, default=1)
625
  parser.add_argument("--config", type=str, default=None)
626
+ parser.add_argument("--not_output_config", action="store_true")
627
+
628
+ class MyNetwork_Names:
629
+ ex_block_weight_dic = {
630
+ "BASE": ["te"],
631
+ "IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
632
+ "IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
633
+ "IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
634
+ "IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
635
+ "MID": ["mid"],
636
+ "OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
637
+ "OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
638
+ "OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
639
+ "OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
640
+ }
641
+
642
+ blocks_name_dic = { "te": "lora_te_",
643
+ "unet": "lora_unet_",
644
+ "mid": "lora_unet_mid_block_",
645
+ "down": "lora_unet_down_blocks_",
646
+ "up": "lora_unet_up_blocks_"}
647
+ for i in range(12):
648
+ blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
649
+ for i in range(3):
650
+ blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
651
+ blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
652
+ for i in range(4):
653
+ for j in range(2):
654
+ if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
655
+ blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
656
+ for j in range(3):
657
+ if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
658
+ blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
659
+ if i<=2:
660
+ blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
661
+ blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
662
+
663
+ def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
664
+ ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
665
+ blocks_name_dic = MyNetwork_Names.blocks_name_dic
666
+
667
+ lr_dic = {}
668
+ if lr_setting_str==None or lr_setting_str=="":
669
+ pass
670
+ else:
671
+ lr_settings = lr_setting_str.replace(" ", "").split(",")
672
+ for lr_setting in lr_settings:
673
+ key, value = lr_setting.split("=")
674
+ if key in ex_block_weight_dic:
675
+ keys = ex_block_weight_dic[key]
676
+ else:
677
+ keys = [key]
678
+ for key in keys:
679
+ if key in blocks_name_dic:
680
+ new_key = blocks_name_dic[key]
681
+ lr_dic[new_key] = float(value)
682
+ if len(lr_dic)==0:
683
+ lr_dic = None
684
+
685
+ args_dic = {}
686
+ if (block_optim_args is None):
687
+ block_optim_args = []
688
+ if (len(block_optim_args)>0):
689
+ for my_arg in block_optim_args:
690
+ my_arg = my_arg.replace(" ", "")
691
+ splits = my_arg.split(":")
692
+ b_name = splits[0]
693
+
694
+ key, _value = splits[1].split("=")
695
+ value_type = float
696
+ if len(splits)==3:
697
+ if _value=="str":
698
+ value_type = str
699
+ elif _value=="int":
700
+ value_type = int
701
+ _value = splits[2]
702
+ if _value=="true" or _value=="false":
703
+ value_type = bool
704
+ if "," in _value:
705
+ _value = _value.split(",")
706
+ for i in range(len(_value)):
707
+ _value[i] = value_type(_value[i])
708
+ value=tuple(_value)
709
+ else:
710
+ value = value_type(_value)
711
+
712
+ if b_name in ex_block_weight_dic:
713
+ b_names = ex_block_weight_dic[b_name]
714
+ else:
715
+ b_names = [b_name]
716
+ for b_name in b_names:
717
+ new_b_name = blocks_name_dic[b_name]
718
+ if not new_b_name in args_dic:
719
+ args_dic[new_b_name] = {}
720
+ args_dic[new_b_name][key] = value
721
+
722
+ if len(args_dic)==0:
723
+ args_dic = None
724
+ return lr_dic, args_dic
725
 
726
  def create_split_names(split_flag, split_level):
727
  split_names = None
 
732
  if split_level==1:
733
  unet_names.append(f"lora_unet_down_blocks_")
734
  unet_names.append(f"lora_unet_up_blocks_")
735
+ elif split_level==2 or split_level==0 or split_level==4:
736
+ if split_level>=2:
737
  text_encoder_names = []
738
  for i in range(12):
739
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
740
+
741
+ if split_level<=2:
742
+ for i in range(3):
743
+ unet_names.append(f"lora_unet_down_blocks_{i}")
744
+ unet_names.append(f"lora_unet_up_blocks_{i+1}")
745
+
746
+ if split_level>=3:
747
+ for i in range(4):
748
+ for j in range(2):
749
+ if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
750
+ if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
751
+ for j in range(3):
752
+ if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
753
+ if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
754
+ if i<=2:
755
+ unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
756
+
757
  split_names["text_encoder"] = text_encoder_names
758
  split_names["unet"] = unet_names
759
  return split_names
 
765
  import datetime
766
  if os.path.splitext(args.config)[-1] == ".yaml":
767
  args.config = os.path.splitext(args.config)[0]
768
+ config_path = f"{args.config}.yaml"
769
  if os.path.exists(config_path):
770
  print(f"{config_path} から設定を読み込み中...")
771
  margs, rest = parser.parse_known_args()
 
786
  args_type_dic[key] = act.type
787
  #データタイプの確認とargsにkeyの内容を代入していく
788
  for key, v in configs.items():
789
+ if v is not None:
790
+ if key in args_dic:
791
+ if args_dic[key] is not None:
792
+ new_type = type(args_dic[key])
793
+ if (not type(v) == new_type) and (not new_type==list):
794
+ v = new_type(v)
795
+ else:
796
  if not type(v) == args_type_dic[key]:
797
  v = args_type_dic[key](v)
798
+ args_dic[key] = v
799
  #最後にデフォから指定が変わってるものを変更する
800
  for key, v in change_def_dic.items():
801
  args_dic[key] = v
802
  else:
803
  print(f"{config_path} が見つかりませんでした")
804
  return args
805
+
806
+ '''
807
+ class GradientReversalFunction(torch.autograd.Function):
808
+ @staticmethod
809
+ def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
810
+ ctx.save_for_backward(scale)
811
+ return input_forward
812
+ @staticmethod
813
+ def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
814
+ scale, = ctx.saved_tensors
815
+ return scale * -grad_backward, None
816
+
817
+ class GradientReversal(torch.nn.Module):
818
+ def __init__(self, scale: float):
819
+ super(GradientReversal, self).__init__()
820
+ self.scale = torch.tensor(scale)
821
+ def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
822
+ if flag:
823
+ return x
824
+ else:
825
+ return GradientReversalFunction.apply(x, self.scale)
826
+ '''
bitsandbytes_windows/cextension.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes as ct
2
+ from pathlib import Path
3
+ from warnings import warn
4
+
5
+ from .cuda_setup.main import evaluate_cuda_setup
6
+
7
+
8
+ class CUDALibrary_Singleton(object):
9
+ _instance = None
10
+
11
+ def __init__(self):
12
+ raise RuntimeError("Call get_instance() instead")
13
+
14
+ def initialize(self):
15
+ binary_name = evaluate_cuda_setup()
16
+ package_dir = Path(__file__).parent
17
+ binary_path = package_dir / binary_name
18
+
19
+ if not binary_path.exists():
20
+ print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
21
+ legacy_binary_name = "libbitsandbytes.so"
22
+ print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
23
+ binary_path = package_dir / legacy_binary_name
24
+ if not binary_path.exists():
25
+ print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
26
+ print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
27
+ raise Exception('CUDA SETUP: Setup Failed!')
28
+ # self.lib = ct.cdll.LoadLibrary(binary_path)
29
+ self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
30
+ else:
31
+ print(f"CUDA SETUP: Loading binary {binary_path}...")
32
+ # self.lib = ct.cdll.LoadLibrary(binary_path)
33
+ self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
34
+
35
+ @classmethod
36
+ def get_instance(cls):
37
+ if cls._instance is None:
38
+ cls._instance = cls.__new__(cls)
39
+ cls._instance.initialize()
40
+ return cls._instance
41
+
42
+
43
+ lib = CUDALibrary_Singleton.get_instance().lib
44
+ try:
45
+ lib.cadam32bit_g32
46
+ lib.get_context.restype = ct.c_void_p
47
+ lib.get_cusparse.restype = ct.c_void_p
48
+ COMPILED_WITH_CUDA = True
49
+ except AttributeError:
50
+ warn(
51
+ "The installed version of bitsandbytes was compiled without GPU support. "
52
+ "8-bit optimizers and GPU quantization are unavailable."
53
+ )
54
+ COMPILED_WITH_CUDA = False
bitsandbytes_windows/libbitsandbytes_cpu.dll ADDED
Binary file (76.3 kB). View file
 
bitsandbytes_windows/libbitsandbytes_cuda116.dll ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88f7bd2916ca3effc43f88492f1e1b9088d13cb5be3b4a3a4aede6aa3bf8d412
3
+ size 4724224
bitsandbytes_windows/main.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ extract factors the build is dependent on:
3
+ [X] compute capability
4
+ [ ] TODO: Q - What if we have multiple GPUs of different makes?
5
+ - CUDA version
6
+ - Software:
7
+ - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
8
+ - CuBLAS-LT: full-build 8-bit optimizer
9
+ - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
10
+
11
+ evaluation:
12
+ - if paths faulty, return meaningful error
13
+ - else:
14
+ - determine CUDA version
15
+ - determine capabilities
16
+ - based on that set the default path
17
+ """
18
+
19
+ import ctypes
20
+
21
+ from .paths import determine_cuda_runtime_lib_path
22
+
23
+
24
+ def check_cuda_result(cuda, result_val):
25
+ # 3. Check for CUDA errors
26
+ if result_val != 0:
27
+ error_str = ctypes.c_char_p()
28
+ cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
29
+ print(f"CUDA exception! Error code: {error_str.value.decode()}")
30
+
31
+ def get_cuda_version(cuda, cudart_path):
32
+ # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
33
+ try:
34
+ cudart = ctypes.CDLL(cudart_path)
35
+ except OSError:
36
+ # TODO: shouldn't we error or at least warn here?
37
+ print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
38
+ return None
39
+
40
+ version = ctypes.c_int()
41
+ check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
42
+ version = int(version.value)
43
+ major = version//1000
44
+ minor = (version-(major*1000))//10
45
+
46
+ if major < 11:
47
+ print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
48
+
49
+ return f'{major}{minor}'
50
+
51
+
52
+ def get_cuda_lib_handle():
53
+ # 1. find libcuda.so library (GPU driver) (/usr/lib)
54
+ try:
55
+ cuda = ctypes.CDLL("libcuda.so")
56
+ except OSError:
57
+ # TODO: shouldn't we error or at least warn here?
58
+ print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
59
+ return None
60
+ check_cuda_result(cuda, cuda.cuInit(0))
61
+
62
+ return cuda
63
+
64
+
65
+ def get_compute_capabilities(cuda):
66
+ """
67
+ 1. find libcuda.so library (GPU driver) (/usr/lib)
68
+ init_device -> init variables -> call function by reference
69
+ 2. call extern C function to determine CC
70
+ (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
71
+ 3. Check for CUDA errors
72
+ https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
73
+ # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
74
+ """
75
+
76
+
77
+ nGpus = ctypes.c_int()
78
+ cc_major = ctypes.c_int()
79
+ cc_minor = ctypes.c_int()
80
+
81
+ device = ctypes.c_int()
82
+
83
+ check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
84
+ ccs = []
85
+ for i in range(nGpus.value):
86
+ check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
87
+ ref_major = ctypes.byref(cc_major)
88
+ ref_minor = ctypes.byref(cc_minor)
89
+ # 2. call extern C function to determine CC
90
+ check_cuda_result(
91
+ cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
92
+ )
93
+ ccs.append(f"{cc_major.value}.{cc_minor.value}")
94
+
95
+ return ccs
96
+
97
+
98
+ # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
99
+ def get_compute_capability(cuda):
100
+ """
101
+ Extracts the highest compute capbility from all available GPUs, as compute
102
+ capabilities are downwards compatible. If no GPUs are detected, it returns
103
+ None.
104
+ """
105
+ ccs = get_compute_capabilities(cuda)
106
+ if ccs is not None:
107
+ # TODO: handle different compute capabilities; for now, take the max
108
+ return ccs[-1]
109
+ return None
110
+
111
+
112
+ def evaluate_cuda_setup():
113
+ print('')
114
+ print('='*35 + 'BUG REPORT' + '='*35)
115
+ print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
116
+ print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
117
+ print('='*80)
118
+ return "libbitsandbytes_cuda116.dll" # $$$
119
+
120
+ binary_name = "libbitsandbytes_cpu.so"
121
+ #if not torch.cuda.is_available():
122
+ #print('No GPU detected. Loading CPU library...')
123
+ #return binary_name
124
+
125
+ cudart_path = determine_cuda_runtime_lib_path()
126
+ if cudart_path is None:
127
+ print(
128
+ "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
129
+ )
130
+ return binary_name
131
+
132
+ print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
133
+ cuda = get_cuda_lib_handle()
134
+ cc = get_compute_capability(cuda)
135
+ print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
136
+ cuda_version_string = get_cuda_version(cuda, cudart_path)
137
+
138
+
139
+ if cc == '':
140
+ print(
141
+ "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
142
+ )
143
+ return binary_name
144
+
145
+ # 7.5 is the minimum CC vor cublaslt
146
+ has_cublaslt = cc in ["7.5", "8.0", "8.6"]
147
+
148
+ # TODO:
149
+ # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
150
+ # (2) Multiple CUDA versions installed
151
+
152
+ # we use ls -l instead of nvcc to determine the cuda version
153
+ # since most installations will have the libcudart.so installed, but not the compiler
154
+ print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
155
+
156
+ def get_binary_name():
157
+ "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
158
+ bin_base_name = "libbitsandbytes_cuda"
159
+ if has_cublaslt:
160
+ return f"{bin_base_name}{cuda_version_string}.so"
161
+ else:
162
+ return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
163
+
164
+ binary_name = get_binary_name()
165
+
166
+ return binary_name
config_README-ja.md ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
2
+
3
+ `--dataset_config` で渡すことができる設定ファイルに関する説明です。
4
+
5
+ ## 概要
6
+
7
+ 設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
8
+
9
+ * 複数のデータセットが設定可能になります
10
+ * 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
11
+ * DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
12
+ * サブセットごとに設定を変更することが可能になります
13
+ * データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
14
+ * `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
15
+
16
+ 設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
17
+
18
+ TOML で記述した設定ファイルの例です。
19
+
20
+ ```toml
21
+ [general]
22
+ shuffle_caption = true
23
+ caption_extension = '.txt'
24
+ keep_tokens = 1
25
+
26
+ # これは DreamBooth 方式のデータセット
27
+ [[datasets]]
28
+ resolution = 512
29
+ batch_size = 4
30
+ keep_tokens = 2
31
+
32
+ [[datasets.subsets]]
33
+ image_dir = 'C:\hoge'
34
+ class_tokens = 'hoge girl'
35
+ # このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
36
+
37
+ [[datasets.subsets]]
38
+ image_dir = 'C:\fuga'
39
+ class_tokens = 'fuga boy'
40
+ keep_tokens = 3
41
+
42
+ [[datasets.subsets]]
43
+ is_reg = true
44
+ image_dir = 'C:\reg'
45
+ class_tokens = 'human'
46
+ keep_tokens = 1
47
+
48
+ # これは fine tuning 方式のデータセット
49
+ [[datasets]]
50
+ resolution = [768, 768]
51
+ batch_size = 2
52
+
53
+ [[datasets.subsets]]
54
+ image_dir = 'C:\piyo'
55
+ metadata_file = 'C:\piyo\piyo_md.json'
56
+ # このサブセットは keep_tokens = 1 (general の値が使われる)
57
+ ```
58
+
59
+ この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
60
+
61
+ ## データセット・サブセットに関する設定
62
+
63
+ データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。
64
+
65
+ * `[general]`
66
+ * 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。
67
+ * データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。
68
+ * `[[datasets]]`
69
+ * `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。
70
+ * サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。
71
+ * `[[datasets.subsets]]`
72
+ * `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。
73
+
74
+ 先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。
75
+
76
+ ```
77
+ C:\
78
+ ├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
79
+ ├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
80
+ ├─ reg -> [[datasets.subsets]] No.3 ┘ |
81
+ └─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
82
+ ```
83
+
84
+ 画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。
85
+
86
+ 登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。
87
+
88
+ 加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。
89
+
90
+ * DreamBooth 方式専用のオプション
91
+ * fine tuning 方式専用のオプション
92
+ * caption dropout の手法が使える場合のオプション
93
+
94
+ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。
95
+ 併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。
96
+ つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。
97
+
98
+ プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。
99
+ そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。
100
+
101
+ 以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。
102
+
103
+ ### 全学習方法で共通のオプション
104
+
105
+ 学習方法によらずに指定可能なオプションです。
106
+
107
+ #### データセット向けオプション
108
+
109
+ データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。
110
+
111
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` |
112
+ | ---- | ---- | ---- | ---- |
113
+ | `batch_size` | `1` | o | o |
114
+ | `bucket_no_upscale` | `true` | o | o |
115
+ | `bucket_reso_steps` | `64` | o | o |
116
+ | `enable_bucket` | `true` | o | o |
117
+ | `max_bucket_reso` | `1024` | o | o |
118
+ | `min_bucket_reso` | `128` | o | o |
119
+ | `resolution` | `256`, `[512, 512]` | o | o |
120
+
121
+ * `batch_size`
122
+ * コマンドライン引数の `--train_batch_size` と同等です。
123
+
124
+ これらの設定はデータセットごとに固定です。
125
+ つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
126
+ 例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。
127
+
128
+ #### サブセット向けオプション
129
+
130
+ サブセットの設定に関わるオプションです。
131
+
132
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
133
+ | ---- | ---- | ---- | ---- | ---- |
134
+ | `color_aug` | `false` | o | o | o |
135
+ | `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
136
+ | `flip_aug` | `true` | o | o | o |
137
+ | `keep_tokens` | `2` | o | o | o |
138
+ | `num_repeats` | `10` | o | o | o |
139
+ | `random_crop` | `false` | o | o | o |
140
+ | `shuffle_caption` | `true` | o | o | o |
141
+
142
+ * `num_repeats`
143
+ * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
144
+
145
+ ### DreamBooth 方式専用のオプション
146
+
147
+ DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
148
+
149
+ #### サブセット向けオプション
150
+
151
+ DreamBooth 方式のサブセットの設定に関わるオプションです。
152
+
153
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
154
+ | ---- | ---- | ---- | ---- | ---- |
155
+ | `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
156
+ | `caption_extension` | `".txt"` | o | o | o |
157
+ | `class_tokens` | `“sks girl”` | - | - | o |
158
+ | `is_reg` | `false` | - | - | o |
159
+
160
+ まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
161
+
162
+ * `image_dir`
163
+ * 画像ディレクトリのパスを指定します。指定必須オプションです。
164
+ * 画像はディレクトリ直下に置かれている必要があります。
165
+ * `class_tokens`
166
+ * クラストークンを設定します。
167
+ * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイル���見つからなかった場合にはエラーになります。
168
+ * `is_reg`
169
+ * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
170
+
171
+ ### fine tuning 方式専用のオプション
172
+
173
+ fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。
174
+
175
+ #### サブセット向けオプション
176
+
177
+ fine tuning 方式のサブセットの設定に関わるオプションです。
178
+
179
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
180
+ | ---- | ---- | ---- | ---- | ---- |
181
+ | `image_dir` | `‘C:\hoge’` | - | - | o |
182
+ | `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) |
183
+
184
+ * `image_dir`
185
+ * 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。
186
+ * 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。
187
+ * 画像はディレクトリ直下に置かれている必要があります。
188
+ * `metadata_file`
189
+ * サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。
190
+ * コマンドライン引数の `--in_json` と同等です。
191
+ * サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。
192
+
193
+ ### caption dropout の手法が使える場合に指定可能なオプション
194
+
195
+ caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。
196
+ DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。
197
+
198
+ #### サブセット向けオプション
199
+
200
+ caption dropout が使えるサブセットの設定に関わるオプションです。
201
+
202
+ | オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
203
+ | ---- | ---- | ---- | ---- |
204
+ | `caption_dropout_every_n_epochs` | o | o | o |
205
+ | `caption_dropout_rate` | o | o | o |
206
+ | `caption_tag_dropout_rate` | o | o | o |
207
+
208
+ ## 重複したサブセットが存在する時の挙動
209
+
210
+ DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。
211
+ fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。
212
+ データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。
213
+
214
+ 一方、異なるデータセットに所属している場合は、重複しているとは見なされません。
215
+ 例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。
216
+ これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。
217
+
218
+ ```toml
219
+ # 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる
220
+
221
+ [[datasets]]
222
+ resolution = 512
223
+
224
+ [[datasets.subsets]]
225
+ image_dir = 'C:\hoge'
226
+
227
+ [[datasets]]
228
+ resolution = 768
229
+
230
+ [[datasets.subsets]]
231
+ image_dir = 'C:\hoge'
232
+ ```
233
+
234
+ ## コマンドライン引数との併用
235
+
236
+ 設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
237
+
238
+ 以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。
239
+
240
+ * `--train_data_dir`
241
+ * `--reg_data_dir`
242
+ * `--in_json`
243
+
244
+ 以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。
245
+
246
+ | コマンドライン引数のオプション | 優先される設定ファイルのオプション |
247
+ | ---------------------------------- | ---------------------------------- |
248
+ | `--bucket_no_upscale` | |
249
+ | `--bucket_reso_steps` | |
250
+ | `--caption_dropout_every_n_epochs` | |
251
+ | `--caption_dropout_rate` | |
252
+ | `--caption_extension` | |
253
+ | `--caption_tag_dropout_rate` | |
254
+ | `--color_aug` | |
255
+ | `--dataset_repeats` | `num_repeats` |
256
+ | `--enable_bucket` | |
257
+ | `--face_crop_aug_range` | |
258
+ | `--flip_aug` | |
259
+ | `--keep_tokens` | |
260
+ | `--min_bucket_reso` | |
261
+ | `--random_crop` | |
262
+ | `--resolution` | |
263
+ | `--shuffle_caption` | |
264
+ | `--train_batch_size` | `batch_size` |
265
+
266
+ ## エラーの手引き
267
+
268
+ 現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。
269
+ 将来的にはこの問題の改善に取り組む予定です。
270
+
271
+ 次善策として、頻出のエラーとその対処法について載せておきます。
272
+ 正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。
273
+
274
+ * `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。
275
+ * `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。
276
+ * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
277
+ * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
278
+
279
+
fine_tune.py CHANGED
@@ -13,7 +13,11 @@ import diffusers
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
-
 
 
 
 
17
 
18
  def collate_fn(examples):
19
  return examples[0]
@@ -30,25 +34,36 @@ def train(args):
30
 
31
  tokenizer = train_util.load_tokenizer(args)
32
 
33
- train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
34
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
35
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
36
- args.bucket_reso_steps, args.bucket_no_upscale,
37
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
38
- args.dataset_repeats, args.debug_dataset)
39
-
40
- # 学習データのdropout率を設定する
41
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
42
-
43
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
44
 
45
  if args.debug_dataset:
46
- train_util.debug_dataset(train_dataset)
47
  return
48
- if len(train_dataset) == 0:
49
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
50
  return
51
 
 
 
 
52
  # acceleratorを準備する
53
  print("prepare accelerator")
54
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -109,7 +124,7 @@ def train(args):
109
  vae.requires_grad_(False)
110
  vae.eval()
111
  with torch.no_grad():
112
- train_dataset.cache_latents(vae)
113
  vae.to("cpu")
114
  if torch.cuda.is_available():
115
  torch.cuda.empty_cache()
@@ -149,33 +164,13 @@ def train(args):
149
 
150
  # 学習に必要なクラスを準備する
151
  print("prepare optimizer, data loader etc.")
152
-
153
- # 8-bit Adamを使う
154
- if args.use_8bit_adam:
155
- try:
156
- import bitsandbytes as bnb
157
- except ImportError:
158
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
159
- print("use 8-bit Adam optimizer")
160
- optimizer_class = bnb.optim.AdamW8bit
161
- elif args.use_lion_optimizer:
162
- try:
163
- import lion_pytorch
164
- except ImportError:
165
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
166
- print("use Lion optimizer")
167
- optimizer_class = lion_pytorch.Lion
168
- else:
169
- optimizer_class = torch.optim.AdamW
170
-
171
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
172
- optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
173
 
174
  # dataloaderを準備する
175
  # DataLoaderのプロセス数:0はメインプロセスになる
176
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
177
  train_dataloader = torch.utils.data.DataLoader(
178
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
179
 
180
  # 学習ステップ数を計算する
181
  if args.max_train_epochs is not None:
@@ -183,8 +178,9 @@ def train(args):
183
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
- lr_scheduler = diffusers.optimization.get_scheduler(
187
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
188
 
189
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
190
  if args.full_fp16:
@@ -218,7 +214,7 @@ def train(args):
218
  # 学習する
219
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
220
  print("running training / 学習開始")
221
- print(f" num examples / サンプル数: {train_dataset.num_train_images}")
222
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
223
  print(f" num epochs / epoch数: {num_train_epochs}")
224
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -237,7 +233,7 @@ def train(args):
237
 
238
  for epoch in range(num_train_epochs):
239
  print(f"epoch {epoch+1}/{num_train_epochs}")
240
- train_dataset.set_current_epoch(epoch + 1)
241
 
242
  for m in training_models:
243
  m.train()
@@ -286,11 +282,11 @@ def train(args):
286
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
287
 
288
  accelerator.backward(loss)
289
- if accelerator.sync_gradients:
290
  params_to_clip = []
291
  for m in training_models:
292
  params_to_clip.extend(m.parameters())
293
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
294
 
295
  optimizer.step()
296
  lr_scheduler.step()
@@ -301,11 +297,16 @@ def train(args):
301
  progress_bar.update(1)
302
  global_step += 1
303
 
 
 
304
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
305
  if args.logging_dir is not None:
306
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
307
  accelerator.log(logs, step=global_step)
308
 
 
309
  loss_total += current_loss
310
  avr_loss = loss_total / (step+1)
311
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -315,7 +316,7 @@ def train(args):
315
  break
316
 
317
  if args.logging_dir is not None:
318
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
319
  accelerator.log(logs, step=epoch+1)
320
 
321
  accelerator.wait_for_everyone()
@@ -325,6 +326,8 @@ def train(args):
325
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
326
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
327
 
 
 
328
  is_main_process = accelerator.is_main_process
329
  if is_main_process:
330
  unet = unwrap_model(unet)
@@ -351,6 +354,8 @@ if __name__ == '__main__':
351
  train_util.add_dataset_arguments(parser, False, True, True)
352
  train_util.add_training_arguments(parser, False)
353
  train_util.add_sd_saving_arguments(parser)
 
 
354
 
355
  parser.add_argument("--diffusers_xformers", action='store_true',
356
  help='use xformers by diffusers / Diffusersでxformersを使用する')
 
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
+ import library.config_util as config_util
17
+ from library.config_util import (
18
+ ConfigSanitizer,
19
+ BlueprintGenerator,
20
+ )
21
 
22
  def collate_fn(examples):
23
  return examples[0]
 
34
 
35
  tokenizer = train_util.load_tokenizer(args)
36
 
37
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
38
+ if args.dataset_config is not None:
39
+ print(f"Load dataset config from {args.dataset_config}")
40
+ user_config = config_util.load_user_config(args.dataset_config)
41
+ ignored = ["train_data_dir", "in_json"]
42
+ if any(getattr(args, attr) is not None for attr in ignored):
43
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
44
+ else:
45
+ user_config = {
46
+ "datasets": [{
47
+ "subsets": [{
48
+ "image_dir": args.train_data_dir,
49
+ "metadata_file": args.in_json,
50
+ }]
51
+ }]
52
+ }
53
+
54
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.debug_dataset:
58
+ train_util.debug_dataset(train_dataset_group)
59
  return
60
+ if len(train_dataset_group) == 0:
61
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
62
  return
63
 
64
+ if cache_latents:
65
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
+
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
124
  vae.requires_grad_(False)
125
  vae.eval()
126
  with torch.no_grad():
127
+ train_dataset_group.cache_latents(vae)
128
  vae.to("cpu")
129
  if torch.cuda.is_available():
130
  torch.cuda.empty_cache()
 
164
 
165
  # 学習に必要なクラスを準備する
166
  print("prepare optimizer, data loader etc.")
167
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  # dataloaderを準備する
170
  # DataLoaderのプロセス数:0はメインプロセスになる
171
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
172
  train_dataloader = torch.utils.data.DataLoader(
173
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
174
 
175
  # 学習ステップ数を計算する
176
  if args.max_train_epochs is not None:
 
178
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
179
 
180
  # lr schedulerを用意する
181
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
182
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
183
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
184
 
185
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
186
  if args.full_fp16:
 
214
  # 学習する
215
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
216
  print("running training / 学習開始")
217
+ print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
218
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
219
  print(f" num epochs / epoch数: {num_train_epochs}")
220
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
233
 
234
  for epoch in range(num_train_epochs):
235
  print(f"epoch {epoch+1}/{num_train_epochs}")
236
+ train_dataset_group.set_current_epoch(epoch + 1)
237
 
238
  for m in training_models:
239
  m.train()
 
282
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
283
 
284
  accelerator.backward(loss)
285
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
286
  params_to_clip = []
287
  for m in training_models:
288
  params_to_clip.extend(m.parameters())
289
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
 
297
  progress_bar.update(1)
298
  global_step += 1
299
 
300
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
301
+
302
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
303
  if args.logging_dir is not None:
304
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
305
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
306
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
307
  accelerator.log(logs, step=global_step)
308
 
309
+ # TODO moving averageにする
310
  loss_total += current_loss
311
  avr_loss = loss_total / (step+1)
312
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 
316
  break
317
 
318
  if args.logging_dir is not None:
319
+ logs = {"loss/epoch": loss_total / len(train_dataloader)}
320
  accelerator.log(logs, step=epoch+1)
321
 
322
  accelerator.wait_for_everyone()
 
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
329
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
330
+
331
  is_main_process = accelerator.is_main_process
332
  if is_main_process:
333
  unet = unwrap_model(unet)
 
354
  train_util.add_dataset_arguments(parser, False, True, True)
355
  train_util.add_training_arguments(parser, False)
356
  train_util.add_sd_saving_arguments(parser)
357
+ train_util.add_optimizer_arguments(parser)
358
+ config_util.add_config_arguments(parser)
359
 
360
  parser.add_argument("--diffusers_xformers", action='store_true',
361
  help='use xformers by diffusers / Diffusersでxformersを使用する')
fine_tune_README_ja.md ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)
2
+
3
+ [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
4
+
5
+ # 概要
6
+
7
+ Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
8
+
9
+ * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
10
+ * 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
11
+ * トークン長を75から225に拡張する。
12
+ * BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
13
+ * Hypernetworkの学習にも対応する。
14
+ * Stable Diffusion v2.0(baseおよび768/v)に対応。
15
+ * VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
16
+
17
+ デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
18
+
19
+ # 追加機能について
20
+
21
+ ## CLIPの出力の変更
22
+
23
+ プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
24
+ 元のまま、最後の層の出力を用いることも可能です。
25
+
26
+ ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
27
+
28
+ ## 正方形以外の解像度での学習
29
+
30
+ Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
31
+ 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
32
+
33
+ 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
34
+
35
+ ## トークン長の75から225への拡張
36
+
37
+ Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
38
+ ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
39
+
40
+ ※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
41
+
42
+ ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
43
+
44
+ # 学習の手順
45
+
46
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
47
+
48
+ ## データの準備
49
+
50
+ [学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
51
+
52
+ ## 学習の実行
53
+ たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
54
+
55
+ ```
56
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
57
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
58
+ --output_dir=<学習したモデルの出力先フォルダ>
59
+ --output_name=<学習したモデル出力時のファイル名>
60
+ --dataset_config=<データ準備で作成した.tomlファイル>
61
+ --save_model_as=safetensors
62
+ --learning_rate=5e-6 --max_train_steps=10000
63
+ --use_8bit_adam --xformers --gradient_checkpointing
64
+ --mixed_precision=fp16
65
+ ```
66
+
67
+ `num_cpu_threads_per_process` には通常は1を指定するとよいようです。
68
+
69
+ `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
70
+
71
+ `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
72
+
73
+ `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
74
+
75
+ 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
76
+
77
+ 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
78
+
79
+ オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
80
+
81
+ `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
82
+
83
+ ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
84
+
85
+ ### よく使われるオプションについて
86
+
87
+ 以下の場合にはオプションに関するドキュメントを参照してください。
88
+
89
+ - Stable Diffusion 2.xまたはそこからの派生モデルを学習する
90
+ - clip skipを2以上を前提としたモデルを学習する
91
+ - 75トークンを超えたキャプションで学習する
92
+
93
+ ### バッチサイズについて
94
+
95
+ モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
96
+
97
+ ### 学習率について
98
+
99
+ 1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
100
+
101
+ ### 以前の形式のデータセット指定をした場合のコマンドライン
102
+
103
+ 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
104
+
105
+ ```
106
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
107
+ --pretrained_model_name_or_path=model.ckpt
108
+ --in_json meta_lat.json
109
+ --train_data_dir=train_data
110
+ --output_dir=fine_tuned
111
+ --shuffle_caption
112
+ --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
113
+ --use_8bit_adam --xformers --gradient_checkpointing
114
+ --mixed_precision=bf16
115
+ --save_every_n_epochs=4
116
+ ```
117
+
118
+ <!--
119
+ ### 勾配をfp16とした学習(実験的機能)
120
+ full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
121
+
122
+ あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
123
+
124
+ メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
125
+ (余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
126
+
127
+ PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
128
+ -->
129
+
130
+ # fine tuning特有のその他の主なオプション
131
+
132
+ すべてのオプションについては別文書を参照してください。
133
+
134
+ ## `train_text_encoder`
135
+ Text Encoderも学習対象とします。メモリ使用量が若干増加します。
136
+
137
+ 通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
138
+
139
+ ## `diffusers_xformers`
140
+ スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
finetune/blip/blip.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # from models.vit import VisionTransformer, interpolate_pos_embed
12
+ # from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from blip.vit import VisionTransformer, interpolate_pos_embed
14
+ from blip.med import BertConfig, BertModel, BertLMHeadModel
15
+ from transformers import BertTokenizer
16
+
17
+ import torch
18
+ from torch import nn
19
+ import torch.nn.functional as F
20
+
21
+ import os
22
+ from urllib.parse import urlparse
23
+ from timm.models.hub import download_cached_file
24
+
25
+ class BLIP_Base(nn.Module):
26
+ def __init__(self,
27
+ med_config = 'configs/med_config.json',
28
+ image_size = 224,
29
+ vit = 'base',
30
+ vit_grad_ckpt = False,
31
+ vit_ckpt_layer = 0,
32
+ ):
33
+ """
34
+ Args:
35
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
36
+ image_size (int): input image size
37
+ vit (str): model size of vision transformer
38
+ """
39
+ super().__init__()
40
+
41
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
42
+ self.tokenizer = init_tokenizer()
43
+ med_config = BertConfig.from_json_file(med_config)
44
+ med_config.encoder_width = vision_width
45
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
46
+
47
+
48
+ def forward(self, image, caption, mode):
49
+
50
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
51
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
52
+
53
+ if mode=='image':
54
+ # return image features
55
+ image_embeds = self.visual_encoder(image)
56
+ return image_embeds
57
+
58
+ elif mode=='text':
59
+ # return text features
60
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
61
+ return_dict = True, mode = 'text')
62
+ return text_output.last_hidden_state
63
+
64
+ elif mode=='multimodal':
65
+ # return multimodel features
66
+ image_embeds = self.visual_encoder(image)
67
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
68
+
69
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
70
+ output = self.text_encoder(text.input_ids,
71
+ attention_mask = text.attention_mask,
72
+ encoder_hidden_states = image_embeds,
73
+ encoder_attention_mask = image_atts,
74
+ return_dict = True,
75
+ )
76
+ return output.last_hidden_state
77
+
78
+
79
+
80
+ class BLIP_Decoder(nn.Module):
81
+ def __init__(self,
82
+ med_config = 'configs/med_config.json',
83
+ image_size = 384,
84
+ vit = 'base',
85
+ vit_grad_ckpt = False,
86
+ vit_ckpt_layer = 0,
87
+ prompt = 'a picture of ',
88
+ ):
89
+ """
90
+ Args:
91
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
92
+ image_size (int): input image size
93
+ vit (str): model size of vision transformer
94
+ """
95
+ super().__init__()
96
+
97
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
98
+ self.tokenizer = init_tokenizer()
99
+ med_config = BertConfig.from_json_file(med_config)
100
+ med_config.encoder_width = vision_width
101
+ self.text_decoder = BertLMHeadModel(config=med_config)
102
+
103
+ self.prompt = prompt
104
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
105
+
106
+
107
+ def forward(self, image, caption):
108
+
109
+ image_embeds = self.visual_encoder(image)
110
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
111
+
112
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
113
+
114
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
115
+
116
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
117
+ decoder_targets[:,:self.prompt_length] = -100
118
+
119
+ decoder_output = self.text_decoder(text.input_ids,
120
+ attention_mask = text.attention_mask,
121
+ encoder_hidden_states = image_embeds,
122
+ encoder_attention_mask = image_atts,
123
+ labels = decoder_targets,
124
+ return_dict = True,
125
+ )
126
+ loss_lm = decoder_output.loss
127
+
128
+ return loss_lm
129
+
130
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
131
+ image_embeds = self.visual_encoder(image)
132
+
133
+ if not sample:
134
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
135
+
136
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
137
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
138
+
139
+ prompt = [self.prompt] * image.size(0)
140
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
141
+ input_ids[:,0] = self.tokenizer.bos_token_id
142
+ input_ids = input_ids[:, :-1]
143
+
144
+ if sample:
145
+ #nucleus sampling
146
+ outputs = self.text_decoder.generate(input_ids=input_ids,
147
+ max_length=max_length,
148
+ min_length=min_length,
149
+ do_sample=True,
150
+ top_p=top_p,
151
+ num_return_sequences=1,
152
+ eos_token_id=self.tokenizer.sep_token_id,
153
+ pad_token_id=self.tokenizer.pad_token_id,
154
+ repetition_penalty=1.1,
155
+ **model_kwargs)
156
+ else:
157
+ #beam search
158
+ outputs = self.text_decoder.generate(input_ids=input_ids,
159
+ max_length=max_length,
160
+ min_length=min_length,
161
+ num_beams=num_beams,
162
+ eos_token_id=self.tokenizer.sep_token_id,
163
+ pad_token_id=self.tokenizer.pad_token_id,
164
+ repetition_penalty=repetition_penalty,
165
+ **model_kwargs)
166
+
167
+ captions = []
168
+ for output in outputs:
169
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
170
+ captions.append(caption[len(self.prompt):])
171
+ return captions
172
+
173
+
174
+ def blip_decoder(pretrained='',**kwargs):
175
+ model = BLIP_Decoder(**kwargs)
176
+ if pretrained:
177
+ model,msg = load_checkpoint(model,pretrained)
178
+ assert(len(msg.missing_keys)==0)
179
+ return model
180
+
181
+ def blip_feature_extractor(pretrained='',**kwargs):
182
+ model = BLIP_Base(**kwargs)
183
+ if pretrained:
184
+ model,msg = load_checkpoint(model,pretrained)
185
+ assert(len(msg.missing_keys)==0)
186
+ return model
187
+
188
+ def init_tokenizer():
189
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
190
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
191
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
192
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
193
+ return tokenizer
194
+
195
+
196
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
197
+
198
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
199
+ if vit=='base':
200
+ vision_width = 768
201
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
202
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
203
+ drop_path_rate=0 or drop_path_rate
204
+ )
205
+ elif vit=='large':
206
+ vision_width = 1024
207
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
208
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
209
+ drop_path_rate=0.1 or drop_path_rate
210
+ )
211
+ return visual_encoder, vision_width
212
+
213
+ def is_url(url_or_filename):
214
+ parsed = urlparse(url_or_filename)
215
+ return parsed.scheme in ("http", "https")
216
+
217
+ def load_checkpoint(model,url_or_filename):
218
+ if is_url(url_or_filename):
219
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
220
+ checkpoint = torch.load(cached_file, map_location='cpu')
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
223
+ else:
224
+ raise RuntimeError('checkpoint url or path is invalid')
225
+
226
+ state_dict = checkpoint['model']
227
+
228
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
229
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
230
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
231
+ model.visual_encoder_m)
232
+ for key in model.state_dict().keys():
233
+ if key in state_dict.keys():
234
+ if state_dict[key].shape!=model.state_dict()[key].shape:
235
+ del state_dict[key]
236
+
237
+ msg = model.load_state_dict(state_dict,strict=False)
238
+ print('load checkpoint from %s'%url_or_filename)
239
+ return model,msg
240
+
finetune/blip/med.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
finetune/blip/med_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
22
+
finetune/blip/vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
finetune/clean_captions_and_tags.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # このスクリプトのライセンスは、Apache License 2.0とします
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import json
8
+ import re
9
+
10
+ from tqdm import tqdm
11
+
12
+ PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
13
+ PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
14
+ PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
15
+ PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
16
+
17
+ # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
18
+ PATTERNS_REMOVE_IN_MULTI = [
19
+ PATTERN_HAIR_LENGTH,
20
+ PATTERN_HAIR_CUT,
21
+ re.compile(r', [\w\-]+ eyes, '),
22
+ re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
23
+ # 複数の髪型定義がある場合は削除する
24
+ re.compile(
25
+ r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
26
+ ]
27
+
28
+
29
+ def clean_tags(image_key, tags):
30
+ # replace '_' to ' '
31
+ tags = tags.replace('^_^', '^@@@^')
32
+ tags = tags.replace('_', ' ')
33
+ tags = tags.replace('^@@@^', '^_^')
34
+
35
+ # remove rating: deepdanbooruのみ
36
+ tokens = tags.split(", rating")
37
+ if len(tokens) == 1:
38
+ # WD14 taggerのときはこちらになるのでメッセージは出さない
39
+ # print("no rating:")
40
+ # print(f"{image_key} {tags}")
41
+ pass
42
+ else:
43
+ if len(tokens) > 2:
44
+ print("multiple ratings:")
45
+ print(f"{image_key} {tags}")
46
+ tags = tokens[0]
47
+
48
+ tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
49
+
50
+ # 複数の人物がいる場合は髪色等のタグを削除する
51
+ if 'girls' in tags or 'boys' in tags:
52
+ for pat in PATTERNS_REMOVE_IN_MULTI:
53
+ found = pat.findall(tags)
54
+ if len(found) > 1: # 二つ以上、タグがある
55
+ tags = pat.sub("", tags)
56
+
57
+ # 髪の特殊対応
58
+ srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
59
+ if srch_hair_len:
60
+ org = srch_hair_len.group()
61
+ tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
62
+
63
+ found = PATTERN_HAIR.findall(tags)
64
+ if len(found) > 1:
65
+ tags = PATTERN_HAIR.sub("", tags)
66
+
67
+ if srch_hair_len:
68
+ tags = tags.replace(", @@@, ", org) # 戻す
69
+
70
+ # white shirtとshirtみたいな重複タグの削除
71
+ found = PATTERN_WORD.findall(tags)
72
+ for word in found:
73
+ if re.search(f", ((\w+) )+{word}, ", tags):
74
+ tags = tags.replace(f", {word}, ", "")
75
+
76
+ tags = tags.replace(", , ", ", ")
77
+ assert tags.startswith(", ") and tags.endswith(", ")
78
+ tags = tags[2:-2]
79
+ return tags
80
+
81
+
82
+ # 上から順に検索、置換される
83
+ # ('置換元文字列', '置換後文字列')
84
+ CAPTION_REPLACEMENTS = [
85
+ ('anime anime', 'anime'),
86
+ ('young ', ''),
87
+ ('anime girl', 'girl'),
88
+ ('cartoon female', 'girl'),
89
+ ('cartoon lady', 'girl'),
90
+ ('cartoon character', 'girl'), # a or ~s
91
+ ('cartoon woman', 'girl'),
92
+ ('cartoon women', 'girls'),
93
+ ('cartoon girl', 'girl'),
94
+ ('anime female', 'girl'),
95
+ ('anime lady', 'girl'),
96
+ ('anime character', 'girl'), # a or ~s
97
+ ('anime woman', 'girl'),
98
+ ('anime women', 'girls'),
99
+ ('lady', 'girl'),
100
+ ('female', 'girl'),
101
+ ('woman', 'girl'),
102
+ ('women', 'girls'),
103
+ ('people', 'girls'),
104
+ ('person', 'girl'),
105
+ ('a cartoon figure', 'a figure'),
106
+ ('a cartoon image', 'an image'),
107
+ ('a cartoon picture', 'a picture'),
108
+ ('an anime cartoon image', 'an image'),
109
+ ('a cartoon anime drawing', 'a drawing'),
110
+ ('a cartoon drawing', 'a drawing'),
111
+ ('girl girl', 'girl'),
112
+ ]
113
+
114
+
115
+ def clean_caption(caption):
116
+ for rf, rt in CAPTION_REPLACEMENTS:
117
+ replaced = True
118
+ while replaced:
119
+ bef = caption
120
+ caption = caption.replace(rf, rt)
121
+ replaced = bef != caption
122
+ return caption
123
+
124
+
125
+ def main(args):
126
+ if os.path.exists(args.in_json):
127
+ print(f"loading existing metadata: {args.in_json}")
128
+ with open(args.in_json, "rt", encoding='utf-8') as f:
129
+ metadata = json.load(f)
130
+ else:
131
+ print("no metadata / メタデータファイルがありません")
132
+ return
133
+
134
+ print("cleaning captions and tags.")
135
+ image_keys = list(metadata.keys())
136
+ for image_key in tqdm(image_keys):
137
+ tags = metadata[image_key].get('tags')
138
+ if tags is None:
139
+ print(f"image does not have tags / メタデータにタグがありません: {image_key}")
140
+ else:
141
+ org = tags
142
+ tags = clean_tags(image_key, tags)
143
+ metadata[image_key]['tags'] = tags
144
+ if args.debug and org != tags:
145
+ print("FROM: " + org)
146
+ print("TO: " + tags)
147
+
148
+ caption = metadata[image_key].get('caption')
149
+ if caption is None:
150
+ print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
151
+ else:
152
+ org = caption
153
+ caption = clean_caption(caption)
154
+ metadata[image_key]['caption'] = caption
155
+ if args.debug and org != caption:
156
+ print("FROM: " + org)
157
+ print("TO: " + caption)
158
+
159
+ # metadataを書き出して終わり
160
+ print(f"writing metadata: {args.out_json}")
161
+ with open(args.out_json, "wt", encoding='utf-8') as f:
162
+ json.dump(metadata, f, indent=2)
163
+ print("done!")
164
+
165
+
166
+ if __name__ == '__main__':
167
+ parser = argparse.ArgumentParser()
168
+ # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
169
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
170
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
171
+ parser.add_argument("--debug", action="store_true", help="debug mode")
172
+
173
+ args, unknown = parser.parse_known_args()
174
+ if len(unknown) == 1:
175
+ print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
176
+ print("All captions and tags in the metadata are processed.")
177
+ print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
178
+ print("メタデータ内のすべてのキャプションとタグが処理されます。")
179
+ args.in_json = args.out_json
180
+ args.out_json = unknown[0]
181
+ elif len(unknown) > 0:
182
+ raise ValueError(f"error: unrecognized arguments: {unknown}")
183
+
184
+ main(args)
finetune/hypernetwork_nai.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NAI compatible
2
+
3
+ import torch
4
+
5
+
6
+ class HypernetworkModule(torch.nn.Module):
7
+ def __init__(self, dim, multiplier=1.0):
8
+ super().__init__()
9
+
10
+ linear1 = torch.nn.Linear(dim, dim * 2)
11
+ linear2 = torch.nn.Linear(dim * 2, dim)
12
+ linear1.weight.data.normal_(mean=0.0, std=0.01)
13
+ linear1.bias.data.zero_()
14
+ linear2.weight.data.normal_(mean=0.0, std=0.01)
15
+ linear2.bias.data.zero_()
16
+ linears = [linear1, linear2]
17
+
18
+ self.linear = torch.nn.Sequential(*linears)
19
+ self.multiplier = multiplier
20
+
21
+ def forward(self, x):
22
+ return x + self.linear(x) * self.multiplier
23
+
24
+
25
+ class Hypernetwork(torch.nn.Module):
26
+ enable_sizes = [320, 640, 768, 1280]
27
+ # return self.modules[Hypernetwork.enable_sizes.index(size)]
28
+
29
+ def __init__(self, multiplier=1.0) -> None:
30
+ super().__init__()
31
+ self.modules = []
32
+ for size in Hypernetwork.enable_sizes:
33
+ self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34
+ self.register_module(f"{size}_0", self.modules[-1][0])
35
+ self.register_module(f"{size}_1", self.modules[-1][1])
36
+
37
+ def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38
+ blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39
+ for block in blocks:
40
+ for subblk in block:
41
+ if 'SpatialTransformer' in str(type(subblk)):
42
+ for tf_block in subblk.transformer_blocks:
43
+ for attn in [tf_block.attn1, tf_block.attn2]:
44
+ size = attn.context_dim
45
+ if size in Hypernetwork.enable_sizes:
46
+ attn.hypernetwork = self
47
+ else:
48
+ attn.hypernetwork = None
49
+
50
+ def apply_to_diffusers(self, text_encoder, vae, unet):
51
+ blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52
+ for block in blocks:
53
+ if hasattr(block, 'attentions'):
54
+ for subblk in block.attentions:
55
+ if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56
+ for tf_block in subblk.transformer_blocks:
57
+ for attn in [tf_block.attn1, tf_block.attn2]:
58
+ size = attn.to_k.in_features
59
+ if size in Hypernetwork.enable_sizes:
60
+ attn.hypernetwork = self
61
+ else:
62
+ attn.hypernetwork = None
63
+ return True # TODO error checking
64
+
65
+ def forward(self, x, context):
66
+ size = context.shape[-1]
67
+ assert size in Hypernetwork.enable_sizes
68
+ module = self.modules[Hypernetwork.enable_sizes.index(size)]
69
+ return module[0].forward(context), module[1].forward(context)
70
+
71
+ def load_from_state_dict(self, state_dict):
72
+ # old ver to new ver
73
+ changes = {
74
+ 'linear1.bias': 'linear.0.bias',
75
+ 'linear1.weight': 'linear.0.weight',
76
+ 'linear2.bias': 'linear.1.bias',
77
+ 'linear2.weight': 'linear.1.weight',
78
+ }
79
+ for key_from, key_to in changes.items():
80
+ if key_from in state_dict:
81
+ state_dict[key_to] = state_dict[key_from]
82
+ del state_dict[key_from]
83
+
84
+ for size, sd in state_dict.items():
85
+ if type(size) == int:
86
+ self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87
+ self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88
+ return True
89
+
90
+ def get_state_dict(self):
91
+ state_dict = {}
92
+ for i, size in enumerate(Hypernetwork.enable_sizes):
93
+ sd0 = self.modules[i][0].state_dict()
94
+ sd1 = self.modules[i][1].state_dict()
95
+ state_dict[size] = [sd0, sd1]
96
+ return state_dict
finetune/make_captions.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import json
5
+ import random
6
+
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from blip.blip import blip_decoder
14
+ import library.train_util as train_util
15
+
16
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+
19
+ IMAGE_SIZE = 384
20
+
21
+ # 正方形でいいのか? という気がするがソースがそうなので
22
+ IMAGE_TRANSFORM = transforms.Compose([
23
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
26
+ ])
27
+
28
+ # 共通化したいが微妙に処理が異なる……
29
+ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
30
+ def __init__(self, image_paths):
31
+ self.images = image_paths
32
+
33
+ def __len__(self):
34
+ return len(self.images)
35
+
36
+ def __getitem__(self, idx):
37
+ img_path = self.images[idx]
38
+
39
+ try:
40
+ image = Image.open(img_path).convert("RGB")
41
+ # convert to tensor temporarily so dataloader will accept it
42
+ tensor = IMAGE_TRANSFORM(image)
43
+ except Exception as e:
44
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
45
+ return None
46
+
47
+ return (tensor, img_path)
48
+
49
+
50
+ def collate_fn_remove_corrupted(batch):
51
+ """Collate function that allows to remove corrupted examples in the
52
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
53
+ The 'None's in the batch are removed.
54
+ """
55
+ # Filter out all the Nones (corrupted examples)
56
+ batch = list(filter(lambda x: x is not None, batch))
57
+ return batch
58
+
59
+
60
+ def main(args):
61
+ # fix the seed for reproducibility
62
+ seed = args.seed # + utils.get_rank()
63
+ torch.manual_seed(seed)
64
+ np.random.seed(seed)
65
+ random.seed(seed)
66
+
67
+ if not os.path.exists("blip"):
68
+ args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
69
+
70
+ cwd = os.getcwd()
71
+ print('Current Working Directory is: ', cwd)
72
+ os.chdir('finetune')
73
+
74
+ print(f"load images from {args.train_data_dir}")
75
+ image_paths = train_util.glob_images(args.train_data_dir)
76
+ print(f"found {len(image_paths)} images.")
77
+
78
+ print(f"loading BLIP caption: {args.caption_weights}")
79
+ model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
80
+ model.eval()
81
+ model = model.to(DEVICE)
82
+ print("BLIP loaded")
83
+
84
+ # captioningする
85
+ def run_batch(path_imgs):
86
+ imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
87
+
88
+ with torch.no_grad():
89
+ if args.beam_search:
90
+ captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
91
+ max_length=args.max_length, min_length=args.min_length)
92
+ else:
93
+ captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
94
+
95
+ for (image_path, _), caption in zip(path_imgs, captions):
96
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
97
+ f.write(caption + "\n")
98
+ if args.debug:
99
+ print(image_path, caption)
100
+
101
+ # 読み込みの高速化のためにDataLoaderを使うオプション
102
+ if args.max_data_loader_n_workers is not None:
103
+ dataset = ImageLoadingTransformDataset(image_paths)
104
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
105
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
106
+ else:
107
+ data = [[(None, ip)] for ip in image_paths]
108
+
109
+ b_imgs = []
110
+ for data_entry in tqdm(data, smoothing=0.0):
111
+ for data in data_entry:
112
+ if data is None:
113
+ continue
114
+
115
+ img_tensor, image_path = data
116
+ if img_tensor is None:
117
+ try:
118
+ raw_image = Image.open(image_path)
119
+ if raw_image.mode != 'RGB':
120
+ raw_image = raw_image.convert("RGB")
121
+ img_tensor = IMAGE_TRANSFORM(raw_image)
122
+ except Exception as e:
123
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
124
+ continue
125
+
126
+ b_imgs.append((image_path, img_tensor))
127
+ if len(b_imgs) >= args.batch_size:
128
+ run_batch(b_imgs)
129
+ b_imgs.clear()
130
+ if len(b_imgs) > 0:
131
+ run_batch(b_imgs)
132
+
133
+ print("done!")
134
+
135
+
136
+ if __name__ == '__main__':
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
139
+ parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
140
+ help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
141
+ parser.add_argument("--caption_extention", type=str, default=None,
142
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
143
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
144
+ parser.add_argument("--beam_search", action="store_true",
145
+ help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
146
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
147
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
148
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
149
+ parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
150
+ parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
151
+ parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
152
+ parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
153
+ parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
154
+ parser.add_argument("--debug", action="store_true", help="debug mode")
155
+
156
+ args = parser.parse_args()
157
+
158
+ # スペルミスしていたオプションを復元する
159
+ if args.caption_extention is not None:
160
+ args.caption_extension = args.caption_extention
161
+
162
+ main(args)
finetune/make_captions_by_git.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+ from transformers.generation.utils import GenerationMixin
10
+
11
+ import library.train_util as train_util
12
+
13
+
14
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ PATTERN_REPLACE = [
17
+ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
18
+ re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
19
+ re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
20
+ re.compile(r'with the number \d+ on (it|\w+ \w+)'),
21
+ re.compile(r'with the words "'),
22
+ re.compile(r'word \w+ on it'),
23
+ re.compile(r'that says the word \w+ on it'),
24
+ re.compile('that says\'the word "( on it)?'),
25
+ ]
26
+
27
+ # 誤検知しまくりの with the word xxxx を消す
28
+
29
+
30
+ def remove_words(captions, debug):
31
+ removed_caps = []
32
+ for caption in captions:
33
+ cap = caption
34
+ for pat in PATTERN_REPLACE:
35
+ cap = pat.sub("", cap)
36
+ if debug and cap != caption:
37
+ print(caption)
38
+ print(cap)
39
+ removed_caps.append(cap)
40
+ return removed_caps
41
+
42
+
43
+ def collate_fn_remove_corrupted(batch):
44
+ """Collate function that allows to remove corrupted examples in the
45
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
46
+ The 'None's in the batch are removed.
47
+ """
48
+ # Filter out all the Nones (corrupted examples)
49
+ batch = list(filter(lambda x: x is not None, batch))
50
+ return batch
51
+
52
+
53
+ def main(args):
54
+ # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
55
+ org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
56
+ curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
57
+
58
+ # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
59
+ # ここより上で置き換えようとするとすごく大変
60
+ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
61
+ input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
62
+ if input_ids.size()[0] != curr_batch_size[0]:
63
+ input_ids = input_ids.repeat(curr_batch_size[0], 1)
64
+ return input_ids
65
+ GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
66
+
67
+ print(f"load images from {args.train_data_dir}")
68
+ image_paths = train_util.glob_images(args.train_data_dir)
69
+ print(f"found {len(image_paths)} images.")
70
+
71
+ # できればcacheに依存せず明示的にダウンロードしたい
72
+ print(f"loading GIT: {args.model_id}")
73
+ git_processor = AutoProcessor.from_pretrained(args.model_id)
74
+ git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
75
+ print("GIT loaded")
76
+
77
+ # captioningする
78
+ def run_batch(path_imgs):
79
+ imgs = [im for _, im in path_imgs]
80
+
81
+ curr_batch_size[0] = len(path_imgs)
82
+ inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
83
+ generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
84
+ captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
85
+
86
+ if args.remove_words:
87
+ captions = remove_words(captions, args.debug)
88
+
89
+ for (image_path, _), caption in zip(path_imgs, captions):
90
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
91
+ f.write(caption + "\n")
92
+ if args.debug:
93
+ print(image_path, caption)
94
+
95
+ # 読み込みの高速化のためにDataLoaderを使うオプション
96
+ if args.max_data_loader_n_workers is not None:
97
+ dataset = train_util.ImageLoadingDataset(image_paths)
98
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
99
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
100
+ else:
101
+ data = [[(None, ip)] for ip in image_paths]
102
+
103
+ b_imgs = []
104
+ for data_entry in tqdm(data, smoothing=0.0):
105
+ for data in data_entry:
106
+ if data is None:
107
+ continue
108
+
109
+ image, image_path = data
110
+ if image is None:
111
+ try:
112
+ image = Image.open(image_path)
113
+ if image.mode != 'RGB':
114
+ image = image.convert("RGB")
115
+ except Exception as e:
116
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
117
+ continue
118
+
119
+ b_imgs.append((image_path, image))
120
+ if len(b_imgs) >= args.batch_size:
121
+ run_batch(b_imgs)
122
+ b_imgs.clear()
123
+
124
+ if len(b_imgs) > 0:
125
+ run_batch(b_imgs)
126
+
127
+ print("done!")
128
+
129
+
130
+ if __name__ == '__main__':
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
133
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
134
+ parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
135
+ help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
136
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
137
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
138
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
139
+ parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
140
+ parser.add_argument("--remove_words", action="store_true",
141
+ help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
142
+ parser.add_argument("--debug", action="store_true", help="debug mode")
143
+
144
+ args = parser.parse_args()
145
+ main(args)
finetune/merge_captions_to_metadata.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge caption texts to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ caption_path = image_path.with_suffix(args.caption_extension)
30
+ caption = caption_path.read_text(encoding='utf-8').strip()
31
+
32
+ image_key = str(image_path) if args.full_path else image_path.stem
33
+ if image_key not in metadata:
34
+ metadata[image_key] = {}
35
+
36
+ metadata[image_key]['caption'] = caption
37
+ if args.debug:
38
+ print(image_key, caption)
39
+
40
+ # metadataを書き出して終わり
41
+ print(f"writing metadata: {args.out_json}")
42
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
43
+ print("done!")
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
49
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
50
+ parser.add_argument("--in_json", type=str,
51
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
52
+ parser.add_argument("--caption_extention", type=str, default=None,
53
+ help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
54
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
55
+ parser.add_argument("--full_path", action="store_true",
56
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
57
+ parser.add_argument("--recursive", action="store_true",
58
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
59
+ parser.add_argument("--debug", action="store_true", help="debug mode")
60
+
61
+ args = parser.parse_args()
62
+
63
+ # スペルミスしていたオプションを復元する
64
+ if args.caption_extention is not None:
65
+ args.caption_extension = args.caption_extention
66
+
67
+ main(args)
finetune/merge_dd_tags_to_metadata.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge tags to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ tags_path = image_path.with_suffix(args.caption_extension)
30
+ tags = tags_path.read_text(encoding='utf-8').strip()
31
+
32
+ image_key = str(image_path) if args.full_path else image_path.stem
33
+ if image_key not in metadata:
34
+ metadata[image_key] = {}
35
+
36
+ metadata[image_key]['tags'] = tags
37
+ if args.debug:
38
+ print(image_key, tags)
39
+
40
+ # metadataを書き出して終わり
41
+ print(f"writing metadata: {args.out_json}")
42
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
43
+
44
+ print("done!")
45
+
46
+
47
+ if __name__ == '__main__':
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
50
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
51
+ parser.add_argument("--in_json", type=str,
52
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
53
+ parser.add_argument("--full_path", action="store_true",
54
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
55
+ parser.add_argument("--recursive", action="store_true",
56
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
57
+ parser.add_argument("--caption_extension", type=str, default=".txt",
58
+ help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
59
+ parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
60
+
61
+ args = parser.parse_args()
62
+ main(args)
finetune/prepare_buckets_latents.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import json
4
+
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ import torch
10
+ from torchvision import transforms
11
+
12
+ import library.model_util as model_util
13
+ import library.train_util as train_util
14
+
15
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ IMAGE_TRANSFORMS = transforms.Compose(
18
+ [
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5]),
21
+ ]
22
+ )
23
+
24
+
25
+ def collate_fn_remove_corrupted(batch):
26
+ """Collate function that allows to remove corrupted examples in the
27
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
28
+ The 'None's in the batch are removed.
29
+ """
30
+ # Filter out all the Nones (corrupted examples)
31
+ batch = list(filter(lambda x: x is not None, batch))
32
+ return batch
33
+
34
+
35
+ def get_latents(vae, images, weight_dtype):
36
+ img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
37
+ img_tensors = torch.stack(img_tensors)
38
+ img_tensors = img_tensors.to(DEVICE, weight_dtype)
39
+ with torch.no_grad():
40
+ latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
41
+ return latents
42
+
43
+
44
+ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
45
+ if is_full_path:
46
+ base_name = os.path.splitext(os.path.basename(image_key))[0]
47
+ else:
48
+ base_name = image_key
49
+ if flip:
50
+ base_name += '_flip'
51
+ return os.path.join(data_dir, base_name)
52
+
53
+
54
+ def main(args):
55
+ # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
56
+ if args.bucket_reso_steps % 8 > 0:
57
+ print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
58
+
59
+ image_paths = train_util.glob_images(args.train_data_dir)
60
+ print(f"found {len(image_paths)} images.")
61
+
62
+ if os.path.exists(args.in_json):
63
+ print(f"loading existing metadata: {args.in_json}")
64
+ with open(args.in_json, "rt", encoding='utf-8') as f:
65
+ metadata = json.load(f)
66
+ else:
67
+ print(f"no metadata / メタデータファイルがありません: {args.in_json}")
68
+ return
69
+
70
+ weight_dtype = torch.float32
71
+ if args.mixed_precision == "fp16":
72
+ weight_dtype = torch.float16
73
+ elif args.mixed_precision == "bf16":
74
+ weight_dtype = torch.bfloat16
75
+
76
+ vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
77
+ vae.eval()
78
+ vae.to(DEVICE, dtype=weight_dtype)
79
+
80
+ # bucketのサイズを計算する
81
+ max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
82
+ assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
83
+
84
+ bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
85
+ args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
86
+ if not args.bucket_no_upscale:
87
+ bucket_manager.make_buckets()
88
+ else:
89
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
90
+
91
+ # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
92
+ img_ar_errors = []
93
+
94
+ def process_batch(is_last):
95
+ for bucket in bucket_manager.buckets:
96
+ if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
97
+ latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
98
+ assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
99
+ f"latent shape {latents.shape}, {bucket[0][1].shape}"
100
+
101
+ for (image_key, _), latent in zip(bucket, latents):
102
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
103
+ np.savez(npz_file_name, latent)
104
+
105
+ # flip
106
+ if args.flip_aug:
107
+ latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
108
+
109
+ for (image_key, _), latent in zip(bucket, latents):
110
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
111
+ np.savez(npz_file_name, latent)
112
+ else:
113
+ # remove existing flipped npz
114
+ for image_key, _ in bucket:
115
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
116
+ if os.path.isfile(npz_file_name):
117
+ print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
118
+ os.remove(npz_file_name)
119
+
120
+ bucket.clear()
121
+
122
+ # 読み込みの高速化のためにDataLoaderを使うオプション
123
+ if args.max_data_loader_n_workers is not None:
124
+ dataset = train_util.ImageLoadingDataset(image_paths)
125
+ data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
126
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
127
+ else:
128
+ data = [[(None, ip)] for ip in image_paths]
129
+
130
+ bucket_counts = {}
131
+ for data_entry in tqdm(data, smoothing=0.0):
132
+ if data_entry[0] is None:
133
+ continue
134
+
135
+ img_tensor, image_path = data_entry[0]
136
+ if img_tensor is not None:
137
+ image = transforms.functional.to_pil_image(img_tensor)
138
+ else:
139
+ try:
140
+ image = Image.open(image_path)
141
+ if image.mode != 'RGB':
142
+ image = image.convert("RGB")
143
+ except Exception as e:
144
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
145
+ continue
146
+
147
+ image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
148
+ if image_key not in metadata:
149
+ metadata[image_key] = {}
150
+
151
+ # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
152
+
153
+ reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
154
+ img_ar_errors.append(abs(ar_error))
155
+ bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
156
+
157
+ # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
158
+ metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
159
+
160
+ if not args.bucket_no_upscale:
161
+ # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
162
+ assert resized_size[0] == reso[0] or resized_size[1] == reso[
163
+ 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
164
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
165
+ 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
166
+
167
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
168
+ 1], f"internal error resized size is small: {resized_size}, {reso}"
169
+
170
+ # 既に存在するファイルがあればshapeを確認して同じならskipする
171
+ if args.skip_existing:
172
+ npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
173
+ if args.flip_aug:
174
+ npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
175
+
176
+ found = True
177
+ for npz_file in npz_files:
178
+ if not os.path.exists(npz_file):
179
+ found = False
180
+ break
181
+
182
+ dat = np.load(npz_file)['arr_0']
183
+ if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
184
+ found = False
185
+ break
186
+ if found:
187
+ continue
188
+
189
+ # 画像をリサイズしてトリミングする
190
+ # PILにinter_areaがないのでcv2で……
191
+ image = np.array(image)
192
+ if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
193
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
194
+
195
+ if resized_size[0] > reso[0]:
196
+ trim_size = resized_size[0] - reso[0]
197
+ image = image[:, trim_size//2:trim_size//2 + reso[0]]
198
+
199
+ if resized_size[1] > reso[1]:
200
+ trim_size = resized_size[1] - reso[1]
201
+ image = image[trim_size//2:trim_size//2 + reso[1]]
202
+
203
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
204
+
205
+ # # debug
206
+ # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
207
+
208
+ # バッチへ追加
209
+ bucket_manager.add_image(reso, (image_key, image))
210
+
211
+ # バッチを推論するか判定して推論する
212
+ process_batch(False)
213
+
214
+ # 残りを処理する
215
+ process_batch(True)
216
+
217
+ bucket_manager.sort()
218
+ for i, reso in enumerate(bucket_manager.resos):
219
+ count = bucket_counts.get(reso, 0)
220
+ if count > 0:
221
+ print(f"bucket {i} {reso}: {count}")
222
+ img_ar_errors = np.array(img_ar_errors)
223
+ print(f"mean ar error: {np.mean(img_ar_errors)}")
224
+
225
+ # metadataを書き出して終わり
226
+ print(f"writing metadata: {args.out_json}")
227
+ with open(args.out_json, "wt", encoding='utf-8') as f:
228
+ json.dump(metadata, f, indent=2)
229
+ print("done!")
230
+
231
+
232
+ if __name__ == '__main__':
233
+ parser = argparse.ArgumentParser()
234
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
235
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
236
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
237
+ parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
238
+ parser.add_argument("--v2", action='store_true',
239
+ help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
240
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
241
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
242
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
243
+ parser.add_argument("--max_resolution", type=str, default="512,512",
244
+ help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
245
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
246
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
247
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
248
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
249
+ parser.add_argument("--bucket_no_upscale", action="store_true",
250
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
251
+ parser.add_argument("--mixed_precision", type=str, default="no",
252
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
253
+ parser.add_argument("--full_path", action="store_true",
254
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
255
+ parser.add_argument("--flip_aug", action="store_true",
256
+ help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
257
+ parser.add_argument("--skip_existing", action="store_true",
258
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
259
+
260
+ args = parser.parse_args()
261
+ main(args)
finetune/tag_images_by_wd14_tagger.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import glob
4
+ import os
5
+
6
+ from PIL import Image
7
+ import cv2
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ from tensorflow.keras.models import load_model
11
+ from huggingface_hub import hf_hub_download
12
+ import torch
13
+
14
+ import library.train_util as train_util
15
+
16
+ # from wd14 tagger
17
+ IMAGE_SIZE = 448
18
+
19
+ # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
20
+ DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
21
+ FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
22
+ SUB_DIR = "variables"
23
+ SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
24
+ CSV_FILE = FILES[-1]
25
+
26
+
27
+ def preprocess_image(image):
28
+ image = np.array(image)
29
+ image = image[:, :, ::-1] # RGB->BGR
30
+
31
+ # pad to square
32
+ size = max(image.shape[0:2])
33
+ pad_x = size - image.shape[1]
34
+ pad_y = size - image.shape[0]
35
+ pad_l = pad_x // 2
36
+ pad_t = pad_y // 2
37
+ image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
38
+
39
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
40
+ image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
41
+
42
+ image = image.astype(np.float32)
43
+ return image
44
+
45
+
46
+ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
47
+ def __init__(self, image_paths):
48
+ self.images = image_paths
49
+
50
+ def __len__(self):
51
+ return len(self.images)
52
+
53
+ def __getitem__(self, idx):
54
+ img_path = self.images[idx]
55
+
56
+ try:
57
+ image = Image.open(img_path).convert("RGB")
58
+ image = preprocess_image(image)
59
+ tensor = torch.tensor(image)
60
+ except Exception as e:
61
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
62
+ return None
63
+
64
+ return (tensor, img_path)
65
+
66
+
67
+ def collate_fn_remove_corrupted(batch):
68
+ """Collate function that allows to remove corrupted examples in the
69
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
70
+ The 'None's in the batch are removed.
71
+ """
72
+ # Filter out all the Nones (corrupted examples)
73
+ batch = list(filter(lambda x: x is not None, batch))
74
+ return batch
75
+
76
+
77
+ def main(args):
78
+ # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
79
+ # depreacatedの警告が出るけどなくなったらその時
80
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
81
+ if not os.path.exists(args.model_dir) or args.force_download:
82
+ print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
83
+ for file in FILES:
84
+ hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
85
+ for file in SUB_DIR_FILES:
86
+ hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
87
+ args.model_dir, SUB_DIR), force_download=True, force_filename=file)
88
+ else:
89
+ print("using existing wd14 tagger model")
90
+
91
+ # 画像を読み込む
92
+ image_paths = train_util.glob_images(args.train_data_dir)
93
+ print(f"found {len(image_paths)} images.")
94
+
95
+ print("loading model and labels")
96
+ model = load_model(args.model_dir)
97
+
98
+ # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
99
+ # 依存ライブラリを増やしたくないので自力で読むよ
100
+ with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
101
+ reader = csv.reader(f)
102
+ l = [row for row in reader]
103
+ header = l[0] # tag_id,name,category,count
104
+ rows = l[1:]
105
+ assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
106
+
107
+ tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
108
+
109
+ # 推論する
110
+ def run_batch(path_imgs):
111
+ imgs = np.array([im for _, im in path_imgs])
112
+
113
+ probs = model(imgs, training=False)
114
+ probs = probs.numpy()
115
+
116
+ for (image_path, _), prob in zip(path_imgs, probs):
117
+ # 最初の4つはratingなので無視する
118
+ # # First 4 labels are actually ratings: pick one with argmax
119
+ # ratings_names = label_names[:4]
120
+ # rating_index = ratings_names["probs"].argmax()
121
+ # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
122
+
123
+ # それ以降はタグなのでconfidenceがthresholdより高いものを追加する
124
+ # Everything else is tags: pick any where prediction confidence > threshold
125
+ tag_text = ""
126
+ for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
127
+ if p >= args.thresh and i < len(tags):
128
+ tag_text += ", " + tags[i]
129
+
130
+ if len(tag_text) > 0:
131
+ tag_text = tag_text[2:] # 最初の ", " を消す
132
+
133
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
134
+ f.write(tag_text + '\n')
135
+ if args.debug:
136
+ print(image_path, tag_text)
137
+
138
+ # 読み込みの高速化のためにDataLoaderを使うオプション
139
+ if args.max_data_loader_n_workers is not None:
140
+ dataset = ImageLoadingPrepDataset(image_paths)
141
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
142
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
143
+ else:
144
+ data = [[(None, ip)] for ip in image_paths]
145
+
146
+ b_imgs = []
147
+ for data_entry in tqdm(data, smoothing=0.0):
148
+ for data in data_entry:
149
+ if data is None:
150
+ continue
151
+
152
+ image, image_path = data
153
+ if image is not None:
154
+ image = image.detach().numpy()
155
+ else:
156
+ try:
157
+ image = Image.open(image_path)
158
+ if image.mode != 'RGB':
159
+ image = image.convert("RGB")
160
+ image = preprocess_image(image)
161
+ except Exception as e:
162
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
163
+ continue
164
+ b_imgs.append((image_path, image))
165
+
166
+ if len(b_imgs) >= args.batch_size:
167
+ run_batch(b_imgs)
168
+ b_imgs.clear()
169
+
170
+ if len(b_imgs) > 0:
171
+ run_batch(b_imgs)
172
+
173
+ print("done!")
174
+
175
+
176
+ if __name__ == '__main__':
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
179
+ parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
180
+ help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
181
+ parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
182
+ help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
183
+ parser.add_argument("--force_download", action='store_true',
184
+ help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
185
+ parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
186
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
187
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
188
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
189
+ parser.add_argument("--caption_extention", type=str, default=None,
190
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
191
+ parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
192
+ parser.add_argument("--debug", action="store_true", help="debug mode")
193
+
194
+ args = parser.parse_args()
195
+
196
+ # スペルミスしていたオプションを復元する
197
+ if args.caption_extention is not None:
198
+ args.caption_extension = args.caption_extention
199
+
200
+ main(args)
gen_img_diffusers.py CHANGED
@@ -47,7 +47,7 @@ VGG(
47
  """
48
 
49
  import json
50
- from typing import List, Optional, Union
51
  import glob
52
  import importlib
53
  import inspect
@@ -60,7 +60,6 @@ import math
60
  import os
61
  import random
62
  import re
63
- from typing import Any, Callable, List, Optional, Union
64
 
65
  import diffusers
66
  import numpy as np
@@ -81,6 +80,9 @@ from PIL import Image
81
  from PIL.PngImagePlugin import PngInfo
82
 
83
  import library.model_util as model_util
 
 
 
84
 
85
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
86
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +489,9 @@ class PipelineLike():
487
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
488
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
489
 
 
 
 
490
  # Textual Inversion
491
  def add_token_replacement(self, target_token_id, rep_token_ids):
492
  self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +505,11 @@ class PipelineLike():
500
  new_tokens.append(token)
501
  return new_tokens
502
 
 
 
 
503
  # region xformersとか使う部分:独自に書き換えるので関係なし
 
504
  def enable_xformers_memory_efficient_attention(self):
505
  r"""
506
  Enable memory efficient attention as implemented in xformers.
@@ -581,6 +590,8 @@ class PipelineLike():
581
  latents: Optional[torch.FloatTensor] = None,
582
  max_embeddings_multiples: Optional[int] = 3,
583
  output_type: Optional[str] = "pil",
 
 
584
  # return_dict: bool = True,
585
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
586
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -672,6 +683,9 @@ class PipelineLike():
672
  else:
673
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
674
 
 
 
 
675
  if strength < 0 or strength > 1:
676
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
677
 
@@ -752,7 +766,7 @@ class PipelineLike():
752
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
753
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
754
 
755
- if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
756
  if isinstance(clip_guide_images, PIL.Image.Image):
757
  clip_guide_images = [clip_guide_images]
758
 
@@ -765,7 +779,7 @@ class PipelineLike():
765
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
766
  if len(image_embeddings_clip) == 1:
767
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
768
- else:
769
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
770
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
771
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -774,6 +788,10 @@ class PipelineLike():
774
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
775
  if len(image_embeddings_vgg16) == 1:
776
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
 
 
 
 
777
 
778
  # set timesteps
779
  self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -781,7 +799,6 @@ class PipelineLike():
781
  latents_dtype = text_embeddings.dtype
782
  init_latents_orig = None
783
  mask = None
784
- noise = None
785
 
786
  if init_image is None:
787
  # get the initial random noise unless the user supplied it
@@ -813,6 +830,8 @@ class PipelineLike():
813
  if isinstance(init_image[0], PIL.Image.Image):
814
  init_image = [preprocess_image(im) for im in init_image]
815
  init_image = torch.cat(init_image)
 
 
816
 
817
  # mask image to tensor
818
  if mask_image is not None:
@@ -823,9 +842,24 @@ class PipelineLike():
823
 
824
  # encode the init image into latents and scale the latents
825
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
826
- init_latent_dist = self.vae.encode(init_image).latent_dist
827
- init_latents = init_latent_dist.sample(generator=generator)
828
- init_latents = 0.18215 * init_latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829
  if len(init_latents) == 1:
830
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
831
  init_latents_orig = init_latents
@@ -864,12 +898,21 @@ class PipelineLike():
864
  extra_step_kwargs["eta"] = eta
865
 
866
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
 
 
 
 
867
  for i, t in enumerate(tqdm(timesteps)):
868
  # expand the latents if we are doing classifier free guidance
869
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
870
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
871
  # predict the noise residual
872
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
 
 
873
 
874
  # perform guidance
875
  if do_classifier_free_guidance:
@@ -911,8 +954,19 @@ class PipelineLike():
911
  if is_cancelled_callback is not None and is_cancelled_callback():
912
  return None
913
 
 
 
 
914
  latents = 1 / 0.18215 * latents
915
- image = self.vae.decode(latents).sample
 
 
 
 
 
 
 
 
916
 
917
  image = (image / 2 + 0.5).clamp(0, 1)
918
 
@@ -1595,10 +1649,11 @@ def get_unweighted_text_embeddings(
1595
  if pad == eos: # v1
1596
  text_input_chunk[:, -1] = text_input[0, -1]
1597
  else: # v2
1598
- if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
1599
- text_input_chunk[:, -1] = eos
1600
- if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
1601
- text_input_chunk[:, 1] = eos
 
1602
 
1603
  if clip_skip is None or clip_skip == 1:
1604
  text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -1799,7 +1854,7 @@ def preprocess_mask(mask):
1799
  mask = mask.convert("L")
1800
  w, h = mask.size
1801
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1802
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
1803
  mask = np.array(mask).astype(np.float32) / 255.0
1804
  mask = np.tile(mask, (4, 1, 1))
1805
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -1817,6 +1872,35 @@ def preprocess_mask(mask):
1817
  # return text_encoder
1818
 
1819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1820
  def main(args):
1821
  if args.fp16:
1822
  dtype = torch.float16
@@ -1881,10 +1965,7 @@ def main(args):
1881
  # tokenizerを読み込む
1882
  print("loading tokenizer")
1883
  if use_stable_diffusion_format:
1884
- if args.v2:
1885
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1886
- else:
1887
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
1888
 
1889
  # schedulerを用意する
1890
  sched_init_args = {}
@@ -1995,11 +2076,13 @@ def main(args):
1995
  # networkを組み込む
1996
  if args.network_module:
1997
  networks = []
 
1998
  for i, network_module in enumerate(args.network_module):
1999
  print("import network module:", network_module)
2000
  imported_module = importlib.import_module(network_module)
2001
 
2002
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
 
2003
 
2004
  net_kwargs = {}
2005
  if args.network_args and i < len(args.network_args):
@@ -2014,7 +2097,7 @@ def main(args):
2014
  network_weight = args.network_weights[i]
2015
  print("load network weights from:", network_weight)
2016
 
2017
- if model_util.is_safetensors(network_weight):
2018
  from safetensors.torch import safe_open
2019
  with safe_open(network_weight, framework="pt") as f:
2020
  metadata = f.metadata()
@@ -2037,6 +2120,18 @@ def main(args):
2037
  else:
2038
  networks = []
2039
 
 
 
 
 
 
 
 
 
 
 
 
 
2040
  if args.opt_channels_last:
2041
  print(f"set optimizing: channels last")
2042
  text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2145,14 @@ def main(args):
2050
  if vgg16_model is not None:
2051
  vgg16_model.to(memory_format=torch.channels_last)
2052
 
 
 
 
 
2053
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2054
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2055
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
 
2056
  print("pipeline is ready.")
2057
 
2058
  if args.diffusers_xformers:
@@ -2177,18 +2277,34 @@ def main(args):
2177
  mask_images = l
2178
 
2179
  # 画像サイズにオプション指定があるときはリサイズする
2180
- if init_images is not None and args.W is not None and args.H is not None:
2181
- print(f"resize img2img source images to {args.W}*{args.H}")
2182
- init_images = resize_images(init_images, (args.W, args.H))
 
2183
  if mask_images is not None:
2184
  print(f"resize img2img mask images to {args.W}*{args.H}")
2185
  mask_images = resize_images(mask_images, (args.W, args.H))
2186
 
 
 
 
 
 
 
 
 
 
 
 
 
2187
  prev_image = None # for VGG16 guided
2188
  if args.guide_image_path is not None:
2189
- print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
2190
- guide_images = load_images(args.guide_image_path)
2191
- print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
 
 
 
2192
  if len(guide_images) == 0:
2193
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2194
  guide_images = None
@@ -2219,33 +2335,46 @@ def main(args):
2219
  iter_seed = random.randint(0, 0x7fffffff)
2220
 
2221
  # バッチ処理の関数
2222
- def process_batch(batch, highres_fix, highres_1st=False):
2223
  batch_size = len(batch)
2224
 
2225
  # highres_fixの処理
2226
  if highres_fix and not highres_1st:
2227
- # 1st stageのバッチを作成して呼び出す
2228
- print("process 1st stage1")
2229
  batch_1st = []
2230
- for params1, (width, height, steps, scale, negative_scale, strength) in batch:
2231
- width_1st = int(width * args.highres_fix_scale + .5)
2232
- height_1st = int(height * args.highres_fix_scale + .5)
2233
  width_1st = width_1st - width_1st % 32
2234
  height_1st = height_1st - height_1st % 32
2235
- batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
 
 
 
2236
  images_1st = process_batch(batch_1st, True, True)
2237
 
2238
  # 2nd stageのバッチを作成して以下処理する
2239
- print("process 2nd stage1")
 
 
 
 
 
 
 
 
2240
  batch_2nd = []
2241
- for i, (b1, image) in enumerate(zip(batch, images_1st)):
2242
- image = image.resize((width, height), resample=PIL.Image.LANCZOS)
2243
- (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
2244
- batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
 
2245
  batch = batch_2nd
2246
 
2247
- (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
2248
- height, steps, scale, negative_scale, strength) = batch[0]
 
2249
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2250
 
2251
  prompts = []
@@ -2278,7 +2407,7 @@ def main(args):
2278
  all_images_are_same = True
2279
  all_masks_are_same = True
2280
  all_guide_images_are_same = True
2281
- for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2282
  prompts.append(prompt)
2283
  negative_prompts.append(negative_prompt)
2284
  seeds.append(seed)
@@ -2295,9 +2424,13 @@ def main(args):
2295
  all_masks_are_same = mask_images[-2] is mask_image
2296
 
2297
  if guide_image is not None:
2298
- guide_images.append(guide_image)
2299
- if i > 0 and all_guide_images_are_same:
2300
- all_guide_images_are_same = guide_images[-2] is guide_image
 
 
 
 
2301
 
2302
  # make start code
2303
  torch.manual_seed(seed)
@@ -2320,10 +2453,24 @@ def main(args):
2320
  if guide_images is not None and all_guide_images_are_same:
2321
  guide_images = guide_images[0]
2322
 
 
 
 
 
 
 
 
 
2323
  # generate
 
 
 
 
2324
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2325
- output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2326
- if highres_1st and not args.highres_fix_save_1st:
 
 
2327
  return images
2328
 
2329
  # save image
@@ -2398,6 +2545,7 @@ def main(args):
2398
  strength = 0.8 if args.strength is None else args.strength
2399
  negative_prompt = ""
2400
  clip_prompt = None
 
2401
 
2402
  prompt_args = prompt.strip().split(' --')
2403
  prompt = prompt_args[0]
@@ -2461,6 +2609,15 @@ def main(args):
2461
  clip_prompt = m.group(1)
2462
  print(f"clip prompt: {clip_prompt}")
2463
  continue
 
 
 
 
 
 
 
 
 
2464
  except ValueError as ex:
2465
  print(f"Exception in parsing / 解析エラー: {parg}")
2466
  print(ex)
@@ -2498,7 +2655,12 @@ def main(args):
2498
  mask_image = mask_images[global_step % len(mask_images)]
2499
 
2500
  if guide_images is not None:
2501
- guide_image = guide_images[global_step % len(guide_images)]
 
 
 
 
 
2502
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2503
  if prev_image is None:
2504
  print("Generate 1st image without guide image.")
@@ -2506,10 +2668,9 @@ def main(args):
2506
  print("Use previous image as guide image.")
2507
  guide_image = prev_image
2508
 
2509
- # TODO named tupleか何かにする
2510
- b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2511
- (width, height, steps, scale, negative_scale, strength))
2512
- if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
2513
  process_batch(batch_data, highres_fix)
2514
  batch_data.clear()
2515
 
@@ -2553,6 +2714,8 @@ if __name__ == '__main__':
2553
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2554
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2555
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
 
 
2556
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2557
  parser.add_argument('--sampler', type=str, default='ddim',
2558
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2564,6 +2727,8 @@ if __name__ == '__main__':
2564
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2565
  parser.add_argument("--vae", type=str, default=None,
2566
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
 
 
2567
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2568
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2569
  parser.add_argument("--seed", type=int, default=None,
@@ -2578,12 +2743,15 @@ if __name__ == '__main__':
2578
  parser.add_argument("--opt_channels_last", action='store_true',
2579
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2580
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2581
- help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
2582
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2583
- help='Hypernetwork weights to load / Hypernetworkの重み')
2584
- parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
 
2585
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2586
  help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
 
 
2587
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2588
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2589
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,15 +2765,26 @@ if __name__ == '__main__':
2597
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2598
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2599
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2600
- parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
 
2601
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2602
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2603
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2604
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2605
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2606
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
 
 
2607
  parser.add_argument("--negative_scale", type=float, default=None,
2608
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2609
 
 
 
 
 
 
 
 
 
2610
  args = parser.parse_args()
2611
  main(args)
 
47
  """
48
 
49
  import json
50
+ from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
51
  import glob
52
  import importlib
53
  import inspect
 
60
  import os
61
  import random
62
  import re
 
63
 
64
  import diffusers
65
  import numpy as np
 
80
  from PIL.PngImagePlugin import PngInfo
81
 
82
  import library.model_util as model_util
83
+ import library.train_util as train_util
84
+ import tools.original_control_net as original_control_net
85
+ from tools.original_control_net import ControlNetInfo
86
 
87
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
88
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 
489
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
490
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
491
 
492
+ # ControlNet
493
+ self.control_nets: List[ControlNetInfo] = []
494
+
495
  # Textual Inversion
496
  def add_token_replacement(self, target_token_id, rep_token_ids):
497
  self.token_replacements[target_token_id] = rep_token_ids
 
505
  new_tokens.append(token)
506
  return new_tokens
507
 
508
+ def set_control_nets(self, ctrl_nets):
509
+ self.control_nets = ctrl_nets
510
+
511
  # region xformersとか使う部分:独自に書き換えるので関係なし
512
+
513
  def enable_xformers_memory_efficient_attention(self):
514
  r"""
515
  Enable memory efficient attention as implemented in xformers.
 
590
  latents: Optional[torch.FloatTensor] = None,
591
  max_embeddings_multiples: Optional[int] = 3,
592
  output_type: Optional[str] = "pil",
593
+ vae_batch_size: float = None,
594
+ return_latents: bool = False,
595
  # return_dict: bool = True,
596
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
 
683
  else:
684
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
685
 
686
+ vae_batch_size = batch_size if vae_batch_size is None else (
687
+ int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
688
+
689
  if strength < 0 or strength > 1:
690
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
691
 
 
766
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
767
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
768
 
769
+ if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
770
  if isinstance(clip_guide_images, PIL.Image.Image):
771
  clip_guide_images = [clip_guide_images]
772
 
 
779
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
780
  if len(image_embeddings_clip) == 1:
781
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
782
+ elif self.vgg16_guidance_scale > 0:
783
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
784
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
785
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
 
788
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
789
  if len(image_embeddings_vgg16) == 1:
790
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
791
+ else:
792
+ # ControlNetのhintにguide imageを流用する
793
+ # 前処理はControlNet側で行う
794
+ pass
795
 
796
  # set timesteps
797
  self.scheduler.set_timesteps(num_inference_steps, self.device)
 
799
  latents_dtype = text_embeddings.dtype
800
  init_latents_orig = None
801
  mask = None
 
802
 
803
  if init_image is None:
804
  # get the initial random noise unless the user supplied it
 
830
  if isinstance(init_image[0], PIL.Image.Image):
831
  init_image = [preprocess_image(im) for im in init_image]
832
  init_image = torch.cat(init_image)
833
+ if isinstance(init_image, list):
834
+ init_image = torch.stack(init_image)
835
 
836
  # mask image to tensor
837
  if mask_image is not None:
 
842
 
843
  # encode the init image into latents and scale the latents
844
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
845
+ if init_image.size()[2:] == (height // 8, width // 8):
846
+ init_latents = init_image
847
+ else:
848
+ if vae_batch_size >= batch_size:
849
+ init_latent_dist = self.vae.encode(init_image).latent_dist
850
+ init_latents = init_latent_dist.sample(generator=generator)
851
+ else:
852
+ if torch.cuda.is_available():
853
+ torch.cuda.empty_cache()
854
+ init_latents = []
855
+ for i in tqdm(range(0, batch_size, vae_batch_size)):
856
+ init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
857
+ if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
858
+ init_latents.append(init_latent_dist.sample(generator=generator))
859
+ init_latents = torch.cat(init_latents)
860
+
861
+ init_latents = 0.18215 * init_latents
862
+
863
  if len(init_latents) == 1:
864
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
865
  init_latents_orig = init_latents
 
898
  extra_step_kwargs["eta"] = eta
899
 
900
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
901
+
902
+ if self.control_nets:
903
+ guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
904
+
905
  for i, t in enumerate(tqdm(timesteps)):
906
  # expand the latents if we are doing classifier free guidance
907
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
908
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
909
+
910
  # predict the noise residual
911
+ if self.control_nets:
912
+ noise_pred = original_control_net.call_unet_and_control_net(
913
+ i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
914
+ else:
915
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
916
 
917
  # perform guidance
918
  if do_classifier_free_guidance:
 
954
  if is_cancelled_callback is not None and is_cancelled_callback():
955
  return None
956
 
957
+ if return_latents:
958
+ return (latents, False)
959
+
960
  latents = 1 / 0.18215 * latents
961
+ if vae_batch_size >= batch_size:
962
+ image = self.vae.decode(latents).sample
963
+ else:
964
+ if torch.cuda.is_available():
965
+ torch.cuda.empty_cache()
966
+ images = []
967
+ for i in tqdm(range(0, batch_size, vae_batch_size)):
968
+ images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
969
+ image = torch.cat(images)
970
 
971
  image = (image / 2 + 0.5).clamp(0, 1)
972
 
 
1649
  if pad == eos: # v1
1650
  text_input_chunk[:, -1] = text_input[0, -1]
1651
  else: # v2
1652
+ for j in range(len(text_input_chunk)):
1653
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
1654
+ text_input_chunk[j, -1] = eos
1655
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
1656
+ text_input_chunk[j, 1] = eos
1657
 
1658
  if clip_skip is None or clip_skip == 1:
1659
  text_embedding = pipe.text_encoder(text_input_chunk)[0]
 
1854
  mask = mask.convert("L")
1855
  w, h = mask.size
1856
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1857
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
1858
  mask = np.array(mask).astype(np.float32) / 255.0
1859
  mask = np.tile(mask, (4, 1, 1))
1860
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
 
1872
  # return text_encoder
1873
 
1874
 
1875
+ class BatchDataBase(NamedTuple):
1876
+ # バッチ分割が必要ないデータ
1877
+ step: int
1878
+ prompt: str
1879
+ negative_prompt: str
1880
+ seed: int
1881
+ init_image: Any
1882
+ mask_image: Any
1883
+ clip_prompt: str
1884
+ guide_image: Any
1885
+
1886
+
1887
+ class BatchDataExt(NamedTuple):
1888
+ # バッチ分割が必要なデータ
1889
+ width: int
1890
+ height: int
1891
+ steps: int
1892
+ scale: float
1893
+ negative_scale: float
1894
+ strength: float
1895
+ network_muls: Tuple[float]
1896
+
1897
+
1898
+ class BatchData(NamedTuple):
1899
+ return_latents: bool
1900
+ base: BatchDataBase
1901
+ ext: BatchDataExt
1902
+
1903
+
1904
  def main(args):
1905
  if args.fp16:
1906
  dtype = torch.float16
 
1965
  # tokenizerを読み込む
1966
  print("loading tokenizer")
1967
  if use_stable_diffusion_format:
1968
+ tokenizer = train_util.load_tokenizer(args)
 
 
 
1969
 
1970
  # schedulerを用意する
1971
  sched_init_args = {}
 
2076
  # networkを組み込む
2077
  if args.network_module:
2078
  networks = []
2079
+ network_default_muls = []
2080
  for i, network_module in enumerate(args.network_module):
2081
  print("import network module:", network_module)
2082
  imported_module = importlib.import_module(network_module)
2083
 
2084
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
2085
+ network_default_muls.append(network_mul)
2086
 
2087
  net_kwargs = {}
2088
  if args.network_args and i < len(args.network_args):
 
2097
  network_weight = args.network_weights[i]
2098
  print("load network weights from:", network_weight)
2099
 
2100
+ if model_util.is_safetensors(network_weight) and args.network_show_meta:
2101
  from safetensors.torch import safe_open
2102
  with safe_open(network_weight, framework="pt") as f:
2103
  metadata = f.metadata()
 
2120
  else:
2121
  networks = []
2122
 
2123
+ # ControlNetの処理
2124
+ control_nets: List[ControlNetInfo] = []
2125
+ if args.control_net_models:
2126
+ for i, model in enumerate(args.control_net_models):
2127
+ prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
2128
+ weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
2129
+ ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
2130
+
2131
+ ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
2132
+ prep = original_control_net.load_preprocess(prep_type)
2133
+ control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
2134
+
2135
  if args.opt_channels_last:
2136
  print(f"set optimizing: channels last")
2137
  text_encoder.to(memory_format=torch.channels_last)
 
2145
  if vgg16_model is not None:
2146
  vgg16_model.to(memory_format=torch.channels_last)
2147
 
2148
+ for cn in control_nets:
2149
+ cn.unet.to(memory_format=torch.channels_last)
2150
+ cn.net.to(memory_format=torch.channels_last)
2151
+
2152
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2153
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2154
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
2155
+ pipe.set_control_nets(control_nets)
2156
  print("pipeline is ready.")
2157
 
2158
  if args.diffusers_xformers:
 
2277
  mask_images = l
2278
 
2279
  # 画像サイズにオプション指定があるときはリサイズする
2280
+ if args.W is not None and args.H is not None:
2281
+ if init_images is not None:
2282
+ print(f"resize img2img source images to {args.W}*{args.H}")
2283
+ init_images = resize_images(init_images, (args.W, args.H))
2284
  if mask_images is not None:
2285
  print(f"resize img2img mask images to {args.W}*{args.H}")
2286
  mask_images = resize_images(mask_images, (args.W, args.H))
2287
 
2288
+ if networks and mask_images:
2289
+ # mask を領域情報として流用する、現在は1枚だけ対応
2290
+ # TODO 複数のnetwork classの混在時の考慮
2291
+ print("use mask as region")
2292
+ # import cv2
2293
+ # for i in range(3):
2294
+ # cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
2295
+ # cv2.waitKey()
2296
+ # cv2.destroyAllWindows()
2297
+ networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
2298
+ mask_images = None
2299
+
2300
  prev_image = None # for VGG16 guided
2301
  if args.guide_image_path is not None:
2302
+ print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
2303
+ guide_images = []
2304
+ for p in args.guide_image_path:
2305
+ guide_images.extend(load_images(p))
2306
+
2307
+ print(f"loaded {len(guide_images)} guide images for guidance")
2308
  if len(guide_images) == 0:
2309
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2310
  guide_images = None
 
2335
  iter_seed = random.randint(0, 0x7fffffff)
2336
 
2337
  # バッチ処理の関数
2338
+ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
2339
  batch_size = len(batch)
2340
 
2341
  # highres_fixの処理
2342
  if highres_fix and not highres_1st:
2343
+ # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
2344
+ print("process 1st stage")
2345
  batch_1st = []
2346
+ for _, base, ext in batch:
2347
+ width_1st = int(ext.width * args.highres_fix_scale + .5)
2348
+ height_1st = int(ext.height * args.highres_fix_scale + .5)
2349
  width_1st = width_1st - width_1st % 32
2350
  height_1st = height_1st - height_1st % 32
2351
+
2352
+ ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
2353
+ ext.negative_scale, ext.strength, ext.network_muls)
2354
+ batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
2355
  images_1st = process_batch(batch_1st, True, True)
2356
 
2357
  # 2nd stageのバッチを作成して以下処理する
2358
+ print("process 2nd stage")
2359
+ if args.highres_fix_latents_upscaling:
2360
+ org_dtype = images_1st.dtype
2361
+ if images_1st.dtype == torch.bfloat16:
2362
+ images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
2363
+ images_1st = torch.nn.functional.interpolate(
2364
+ images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
2365
+ images_1st = images_1st.to(org_dtype)
2366
+
2367
  batch_2nd = []
2368
+ for i, (bd, image) in enumerate(zip(batch, images_1st)):
2369
+ if not args.highres_fix_latents_upscaling:
2370
+ image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
2371
+ bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
2372
+ batch_2nd.append(bd_2nd)
2373
  batch = batch_2nd
2374
 
2375
+ # このバッチの情報を取り出す
2376
+ return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
2377
+ (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
2378
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2379
 
2380
  prompts = []
 
2407
  all_images_are_same = True
2408
  all_masks_are_same = True
2409
  all_guide_images_are_same = True
2410
+ for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2411
  prompts.append(prompt)
2412
  negative_prompts.append(negative_prompt)
2413
  seeds.append(seed)
 
2424
  all_masks_are_same = mask_images[-2] is mask_image
2425
 
2426
  if guide_image is not None:
2427
+ if type(guide_image) is list:
2428
+ guide_images.extend(guide_image)
2429
+ all_guide_images_are_same = False
2430
+ else:
2431
+ guide_images.append(guide_image)
2432
+ if i > 0 and all_guide_images_are_same:
2433
+ all_guide_images_are_same = guide_images[-2] is guide_image
2434
 
2435
  # make start code
2436
  torch.manual_seed(seed)
 
2453
  if guide_images is not None and all_guide_images_are_same:
2454
  guide_images = guide_images[0]
2455
 
2456
+ # ControlNet使用時はguide imageをリサイズする
2457
+ if control_nets:
2458
+ # TODO resample��メソッド
2459
+ guide_images = guide_images if type(guide_images) == list else [guide_images]
2460
+ guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
2461
+ if len(guide_images) == 1:
2462
+ guide_images = guide_images[0]
2463
+
2464
  # generate
2465
+ if networks:
2466
+ for n, m in zip(networks, network_muls if network_muls else network_default_muls):
2467
+ n.set_multiplier(m)
2468
+
2469
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2470
+ output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
2471
+ vae_batch_size=args.vae_batch_size, return_latents=return_latents,
2472
+ clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2473
+ if highres_1st and not args.highres_fix_save_1st: # return images or latents
2474
  return images
2475
 
2476
  # save image
 
2545
  strength = 0.8 if args.strength is None else args.strength
2546
  negative_prompt = ""
2547
  clip_prompt = None
2548
+ network_muls = None
2549
 
2550
  prompt_args = prompt.strip().split(' --')
2551
  prompt = prompt_args[0]
 
2609
  clip_prompt = m.group(1)
2610
  print(f"clip prompt: {clip_prompt}")
2611
  continue
2612
+
2613
+ m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
2614
+ if m: # network multiplies
2615
+ network_muls = [float(v) for v in m.group(1).split(",")]
2616
+ while len(network_muls) < len(networks):
2617
+ network_muls.append(network_muls[-1])
2618
+ print(f"network mul: {network_muls}")
2619
+ continue
2620
+
2621
  except ValueError as ex:
2622
  print(f"Exception in parsing / 解析エラー: {parg}")
2623
  print(ex)
 
2655
  mask_image = mask_images[global_step % len(mask_images)]
2656
 
2657
  if guide_images is not None:
2658
+ if control_nets: # 複数件の場合あり
2659
+ c = len(control_nets)
2660
+ p = global_step % (len(guide_images) // c)
2661
+ guide_image = guide_images[p * c:p * c + c]
2662
+ else:
2663
+ guide_image = guide_images[global_step % len(guide_images)]
2664
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2665
  if prev_image is None:
2666
  print("Generate 1st image without guide image.")
 
2668
  print("Use previous image as guide image.")
2669
  guide_image = prev_image
2670
 
2671
+ b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2672
+ BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
2673
+ if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
 
2674
  process_batch(batch_data, highres_fix)
2675
  batch_data.clear()
2676
 
 
2714
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2715
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2716
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
2717
+ parser.add_argument("--vae_batch_size", type=float, default=None,
2718
+ help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
2719
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2720
  parser.add_argument('--sampler', type=str, default='ddim',
2721
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
 
2727
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2728
  parser.add_argument("--vae", type=str, default=None,
2729
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
2730
+ parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
2731
+ help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
2732
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2733
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2734
  parser.add_argument("--seed", type=int, default=None,
 
2743
  parser.add_argument("--opt_channels_last", action='store_true',
2744
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2745
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2746
+ help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
2747
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2748
+ help='additional network weights to load / 追加ネットワークの重み')
2749
+ parser.add_argument("--network_mul", type=float, default=None, nargs='*',
2750
+ help='additional network multiplier / 追加ネットワークの効果の倍率')
2751
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2752
  help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
2753
+ parser.add_argument("--network_show_meta", action='store_true',
2754
+ help='show metadata of network model / ネットワークモデルのメタデータを表示する')
2755
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2756
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2757
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
 
2765
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2766
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2767
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2768
+ parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
2769
+ help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
2770
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2771
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2772
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2773
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2774
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2775
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
2776
+ parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
2777
+ help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
2778
  parser.add_argument("--negative_scale", type=float, default=None,
2779
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2780
 
2781
+ parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
2782
+ help='ControlNet models to use / 使用するControlNetのモデル名')
2783
+ parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
2784
+ help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
2785
+ parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
2786
+ parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
2787
+ help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
2788
+
2789
  args = parser.parse_args()
2790
  main(args)
library/model_util.py CHANGED
@@ -4,7 +4,7 @@
4
  import math
5
  import os
6
  import torch
7
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
  from safetensors.torch import load_file, save_file
10
 
@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
916
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
  else:
918
  converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
 
 
919
  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
 
920
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
  print("loading text encoder:", info)
922
 
 
4
  import math
5
  import os
6
  import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
8
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
  from safetensors.torch import load_file, save_file
10
 
 
916
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
  else:
918
  converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
+
920
+ logging.set_verbosity_error() # don't show annoying warning
921
  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
922
+ logging.set_verbosity_warning()
923
+
924
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
925
  print("loading text encoder:", info)
926
 
library/train_util.py CHANGED
@@ -1,12 +1,21 @@
1
  # common functions for training
2
 
3
  import argparse
 
4
  import json
 
5
  import shutil
6
  import time
7
- from typing import Dict, List, NamedTuple, Tuple
 
 
 
 
 
 
 
 
8
  from accelerate import Accelerator
9
- from torch.autograd.function import Function
10
  import glob
11
  import math
12
  import os
@@ -17,10 +26,16 @@ from io import BytesIO
17
 
18
  from tqdm import tqdm
19
  import torch
 
20
  from torchvision import transforms
21
  from transformers import CLIPTokenizer
 
22
  import diffusers
23
- from diffusers import DDPMScheduler, StableDiffusionPipeline
 
 
 
 
24
  import albumentations as albu
25
  import numpy as np
26
  from PIL import Image
@@ -195,23 +210,95 @@ class BucketBatchIndex(NamedTuple):
195
  batch_index: int
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  class BaseDataset(torch.utils.data.Dataset):
199
- def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
  super().__init__()
201
- self.tokenizer: CLIPTokenizer = tokenizer
202
  self.max_token_length = max_token_length
203
- self.shuffle_caption = shuffle_caption
204
- self.shuffle_keep_tokens = shuffle_keep_tokens
205
  # width/height is used when enable_bucket==False
206
  self.width, self.height = (None, None) if resolution is None else resolution
207
- self.face_crop_aug_range = face_crop_aug_range
208
- self.flip_aug = flip_aug
209
- self.color_aug = color_aug
210
  self.debug_dataset = debug_dataset
211
- self.random_crop = random_crop
 
 
212
  self.token_padding_disabled = False
213
- self.dataset_dirs_info = {}
214
- self.reg_dataset_dirs_info = {}
215
  self.tag_frequency = {}
216
 
217
  self.enable_bucket = False
@@ -225,49 +312,28 @@ class BaseDataset(torch.utils.data.Dataset):
225
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
 
227
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
228
- self.dropout_rate: float = 0
229
- self.dropout_every_n_epochs: int = None
230
- self.tag_dropout_rate: float = 0
231
 
232
  # augmentation
233
- flip_p = 0.5 if flip_aug else 0.0
234
- if color_aug:
235
- # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
236
- self.aug = albu.Compose([
237
- albu.OneOf([
238
- albu.HueSaturationValue(8, 0, 0, p=.5),
239
- albu.RandomGamma((95, 105), p=.5),
240
- ], p=.33),
241
- albu.HorizontalFlip(p=flip_p)
242
- ], p=1.)
243
- elif flip_aug:
244
- self.aug = albu.Compose([
245
- albu.HorizontalFlip(p=flip_p)
246
- ], p=1.)
247
- else:
248
- self.aug = None
249
 
250
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
 
252
  self.image_data: Dict[str, ImageInfo] = {}
 
253
 
254
  self.replacements = {}
255
 
256
  def set_current_epoch(self, epoch):
257
  self.current_epoch = epoch
258
-
259
- def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
- # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
261
- self.dropout_rate = dropout_rate
262
- self.dropout_every_n_epochs = dropout_every_n_epochs
263
- self.tag_dropout_rate = tag_dropout_rate
264
 
265
  def set_tag_frequency(self, dir_name, captions):
266
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
  self.tag_frequency[dir_name] = frequency_for_dir
268
  for caption in captions:
269
  for tag in caption.split(","):
270
- if tag and not tag.isspace():
 
271
  tag = tag.lower()
272
  frequency = frequency_for_dir.get(tag, 0)
273
  frequency_for_dir[tag] = frequency + 1
@@ -278,42 +344,36 @@ class BaseDataset(torch.utils.data.Dataset):
278
  def add_replacement(self, str_from, str_to):
279
  self.replacements[str_from] = str_to
280
 
281
- def process_caption(self, caption):
282
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
283
- is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
- is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
 
286
  if is_drop_out:
287
  caption = ""
288
  else:
289
- if self.shuffle_caption or self.tag_dropout_rate > 0:
290
  def dropout_tags(tokens):
291
- if self.tag_dropout_rate <= 0:
292
  return tokens
293
  l = []
294
  for token in tokens:
295
- if random.random() >= self.tag_dropout_rate:
296
  l.append(token)
297
  return l
298
 
299
- tokens = [t.strip() for t in caption.strip().split(",")]
300
- if self.shuffle_keep_tokens is None:
301
- if self.shuffle_caption:
302
- random.shuffle(tokens)
303
-
304
- tokens = dropout_tags(tokens)
305
- else:
306
- if len(tokens) > self.shuffle_keep_tokens:
307
- keep_tokens = tokens[:self.shuffle_keep_tokens]
308
- tokens = tokens[self.shuffle_keep_tokens:]
309
 
310
- if self.shuffle_caption:
311
- random.shuffle(tokens)
312
 
313
- tokens = dropout_tags(tokens)
314
 
315
- tokens = keep_tokens + tokens
316
- caption = ", ".join(tokens)
317
 
318
  # textual inversion対応
319
  for str_from, str_to in self.replacements.items():
@@ -367,8 +427,9 @@ class BaseDataset(torch.utils.data.Dataset):
367
  input_ids = torch.stack(iids_list) # 3,77
368
  return input_ids
369
 
370
- def register_image(self, info: ImageInfo):
371
  self.image_data[info.image_key] = info
 
372
 
373
  def make_buckets(self):
374
  '''
@@ -467,7 +528,7 @@ class BaseDataset(torch.utils.data.Dataset):
467
  img = np.array(image, np.uint8)
468
  return img
469
 
470
- def trim_and_resize_if_required(self, image, reso, resized_size):
471
  image_height, image_width = image.shape[0:2]
472
 
473
  if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -477,22 +538,27 @@ class BaseDataset(torch.utils.data.Dataset):
477
  image_height, image_width = image.shape[0:2]
478
  if image_width > reso[0]:
479
  trim_size = image_width - reso[0]
480
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
  # print("w", trim_size, p)
482
  image = image[:, p:p + reso[0]]
483
  if image_height > reso[1]:
484
  trim_size = image_height - reso[1]
485
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
  # print("h", trim_size, p)
487
  image = image[p:p + reso[1]]
488
 
489
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
  return image
491
 
 
 
 
492
  def cache_latents(self, vae):
493
  # TODO ここを高速化したい
494
  print("caching latents.")
495
  for info in tqdm(self.image_data.values()):
 
 
496
  if info.latents_npz is not None:
497
  info.latents = self.load_latents_from_npz(info, False)
498
  info.latents = torch.FloatTensor(info.latents)
@@ -502,13 +568,13 @@ class BaseDataset(torch.utils.data.Dataset):
502
  continue
503
 
504
  image = self.load_image(info.absolute_path)
505
- image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
 
507
  img_tensor = self.image_transforms(image)
508
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
 
511
- if self.flip_aug:
512
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
  img_tensor = self.image_transforms(image)
514
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -518,11 +584,11 @@ class BaseDataset(torch.utils.data.Dataset):
518
  image = Image.open(image_path)
519
  return image.size
520
 
521
- def load_image_with_face_info(self, image_path: str):
522
  img = self.load_image(image_path)
523
 
524
  face_cx = face_cy = face_w = face_h = 0
525
- if self.face_crop_aug_range is not None:
526
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
  if len(tokens) >= 5:
528
  face_cx = int(tokens[-4])
@@ -533,7 +599,7 @@ class BaseDataset(torch.utils.data.Dataset):
533
  return img, face_cx, face_cy, face_w, face_h
534
 
535
  # いい感じに切り出す
536
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
  height, width = image.shape[0:2]
538
  if height == self.height and width == self.width:
539
  return image
@@ -541,8 +607,8 @@ class BaseDataset(torch.utils.data.Dataset):
541
  # 画像サイズはsizeより大きいのでリサイズする
542
  face_size = max(face_w, face_h)
543
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
544
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
545
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
546
  if min_scale >= max_scale: # range指定がmin==max
547
  scale = min_scale
548
  else:
@@ -560,13 +626,13 @@ class BaseDataset(torch.utils.data.Dataset):
560
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
562
 
563
- if self.random_crop:
564
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
565
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
566
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
567
  else:
568
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
569
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
  if face_size > self.size // 10 and face_size >= 40:
571
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
 
@@ -589,9 +655,6 @@ class BaseDataset(torch.utils.data.Dataset):
589
  return self._length
590
 
591
  def __getitem__(self, index):
592
- if index == 0:
593
- self.shuffle_buckets()
594
-
595
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -604,28 +667,29 @@ class BaseDataset(torch.utils.data.Dataset):
604
 
605
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
  image_info = self.image_data[image_key]
 
607
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
 
609
  # image/latentsを処理する
610
  if image_info.latents is not None:
611
- latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
  image = None
613
  elif image_info.latents_npz is not None:
614
- latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
  latents = torch.FloatTensor(latents)
616
  image = None
617
  else:
618
  # 画像を読み込み、必要ならcropする
619
- img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
  im_h, im_w = img.shape[0:2]
621
 
622
  if self.enable_bucket:
623
- img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
  else:
625
  if face_cx > 0: # 顔位置情報あり
626
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
  elif im_h > self.height or im_w > self.width:
628
- assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
  if im_h > self.height:
630
  p = random.randint(0, im_h - self.height)
631
  img = img[p:p + self.height]
@@ -637,8 +701,9 @@ class BaseDataset(torch.utils.data.Dataset):
637
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
 
639
  # augmentation
640
- if self.aug is not None:
641
- img = self.aug(image=img)['image']
 
642
 
643
  latents = None
644
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -646,7 +711,7 @@ class BaseDataset(torch.utils.data.Dataset):
646
  images.append(image)
647
  latents_list.append(latents)
648
 
649
- caption = self.process_caption(image_info.caption)
650
  captions.append(caption)
651
  if not self.token_padding_disabled: # this option might be omitted in future
652
  input_ids_list.append(self.get_input_ids(caption))
@@ -677,9 +742,8 @@ class BaseDataset(torch.utils.data.Dataset):
677
 
678
 
679
  class DreamBoothDataset(BaseDataset):
680
- def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
 
684
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
 
@@ -702,7 +766,7 @@ class DreamBoothDataset(BaseDataset):
702
  self.bucket_reso_steps = None # この情報は使われない
703
  self.bucket_no_upscale = False
704
 
705
- def read_caption(img_path):
706
  # captionの候補ファイル名を作る
707
  base_name = os.path.splitext(img_path)[0]
708
  base_name_face_det = base_name
@@ -725,153 +789,181 @@ class DreamBoothDataset(BaseDataset):
725
  break
726
  return caption
727
 
728
- def load_dreambooth_dir(dir):
729
- if not os.path.isdir(dir):
730
- # print(f"ignore file: {dir}")
731
- return 0, [], []
732
 
733
- tokens = os.path.basename(dir).split('_')
734
- try:
735
- n_repeats = int(tokens[0])
736
- except ValueError as e:
737
- print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
- return 0, [], []
739
-
740
- caption_by_folder = '_'.join(tokens[1:])
741
- img_paths = glob_images(dir, "*")
742
- print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
 
744
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
745
  captions = []
746
  for img_path in img_paths:
747
- cap_for_img = read_caption(img_path)
748
- captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
 
 
 
 
749
 
750
- self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
751
 
752
- return n_repeats, img_paths, captions
753
 
754
- print("prepare train images.")
755
- train_dirs = os.listdir(train_data_dir)
756
  num_train_images = 0
757
- for dir in train_dirs:
758
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
- num_train_images += n_repeats * len(img_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
 
761
  for img_path, caption in zip(img_paths, captions):
762
- info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
- self.register_image(info)
 
 
 
764
 
765
- self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
766
 
767
  print(f"{num_train_images} train images with repeating.")
768
  self.num_train_images = num_train_images
769
 
770
- # reg imageは数を数えて学習画像と同じ枚数にする
771
- num_reg_images = 0
772
- if reg_data_dir:
773
- print("prepare reg images.")
774
- reg_infos: List[ImageInfo] = []
775
 
776
- reg_dirs = os.listdir(reg_data_dir)
777
- for dir in reg_dirs:
778
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
- num_reg_images += n_repeats * len(img_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
781
- for img_path, caption in zip(img_paths, captions):
782
- info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
- reg_infos.append(info)
784
 
785
- self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
 
787
- print(f"{num_reg_images} reg images.")
788
- if num_train_images < num_reg_images:
789
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
 
 
 
 
 
790
 
791
- if num_reg_images == 0:
792
- print("no regularization images / 正則化画像が見つかりませんでした")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793
  else:
794
- # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
795
- n = 0
796
- first_loop = True
797
- while n < num_train_images:
798
- for info in reg_infos:
799
- if first_loop:
800
- self.register_image(info)
801
- n += info.num_repeats
802
- else:
803
- info.num_repeats += 1
804
- n += 1
805
- if n >= num_train_images:
806
- break
807
- first_loop = False
808
 
809
- self.num_reg_images = num_reg_images
 
 
810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
 
812
- class FineTuningDataset(BaseDataset):
813
- def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
-
817
- # メタデータを読み込む
818
- if os.path.exists(json_file_name):
819
- print(f"loading existing metadata: {json_file_name}")
820
- with open(json_file_name, "rt", encoding='utf-8') as f:
821
- metadata = json.load(f)
822
- else:
823
- raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
 
825
- self.metadata = metadata
826
- self.train_data_dir = train_data_dir
827
- self.batch_size = batch_size
828
 
829
- tags_list = []
830
- for image_key, img_md in metadata.items():
831
- # path情報を作る
832
- if os.path.exists(image_key):
833
- abs_path = image_key
834
- else:
835
- # わりといい加減だがいい方法が思いつかん
836
- abs_path = glob_images(train_data_dir, image_key)
837
- assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
- abs_path = abs_path[0]
839
-
840
- caption = img_md.get('caption')
841
- tags = img_md.get('tags')
842
- if caption is None:
843
- caption = tags
844
- elif tags is not None and len(tags) > 0:
845
- caption = caption + ', ' + tags
846
- tags_list.append(tags)
847
- assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
-
849
- image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
- image_info.image_size = img_md.get('train_resolution')
851
-
852
- if not self.color_aug and not self.random_crop:
853
- # if npz exists, use them
854
- image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
-
856
- self.register_image(image_info)
857
- self.num_train_images = len(metadata) * dataset_repeats
858
- self.num_reg_images = 0
859
 
860
- # TODO do not record tag freq when no tag
861
- self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
- self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
 
 
 
863
 
864
  # check existence of all npz files
865
- use_npz_latents = not (self.color_aug or self.random_crop)
866
  if use_npz_latents:
 
867
  npz_any = False
868
  npz_all = True
 
869
  for image_info in self.image_data.values():
 
 
870
  has_npz = image_info.latents_npz is not None
871
  npz_any = npz_any or has_npz
872
 
873
- if self.flip_aug:
874
  has_npz = has_npz and image_info.latents_npz_flipped is not None
 
875
  npz_all = npz_all and has_npz
876
 
877
  if npz_any and not npz_all:
@@ -883,7 +975,7 @@ class FineTuningDataset(BaseDataset):
883
  elif not npz_all:
884
  use_npz_latents = False
885
  print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
886
- if self.flip_aug:
887
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
  # else:
889
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -929,7 +1021,7 @@ class FineTuningDataset(BaseDataset):
929
  for image_info in self.image_data.values():
930
  image_info.latents_npz = image_info.latents_npz_flipped = None
931
 
932
- def image_key_to_npz_file(self, image_key):
933
  base_name = os.path.splitext(image_key)[0]
934
  npz_file_norm = base_name + '.npz'
935
 
@@ -941,8 +1033,8 @@ class FineTuningDataset(BaseDataset):
941
  return npz_file_norm, npz_file_flip
942
 
943
  # image_key is relative path
944
- npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
- npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
 
947
  if not os.path.exists(npz_file_norm):
948
  npz_file_norm = None
@@ -953,13 +1045,60 @@ class FineTuningDataset(BaseDataset):
953
  return npz_file_norm, npz_file_flip
954
 
955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
  def debug_dataset(train_dataset, show_input_ids=False):
957
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
  print("Escape for exit. / Escキーで中断、終了します")
959
 
960
  train_dataset.set_current_epoch(1)
961
  k = 0
962
- for i, example in enumerate(train_dataset):
 
 
 
963
  if example['latents'] is not None:
964
  print(f"sample has latents from npz file: {example['latents'].size()}")
965
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
@@ -1364,6 +1503,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1364
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1367
 
1368
 
1369
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1387,10 +1555,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1387
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
- parser.add_argument("--use_8bit_adam", action="store_true",
1391
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
- parser.add_argument("--use_lion_optimizer", action="store_true",
1393
- help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
  parser.add_argument("--mem_eff_attn", action="store_true",
1395
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
  parser.add_argument("--xformers", action="store_true",
@@ -1398,7 +1562,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1398
  parser.add_argument("--vae", type=str, default=None,
1399
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
 
1401
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
  parser.add_argument("--max_train_epochs", type=int, default=None,
1404
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1419,15 +1582,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1419
  parser.add_argument("--logging_dir", type=str, default=None,
1420
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
- parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
  parser.add_argument("--noise_offset", type=float, default=None,
1427
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
  parser.add_argument("--lowram", action="store_true",
1429
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
 
 
 
 
 
 
 
 
 
 
 
 
 
1431
  if support_dreambooth:
1432
  # DreamBooth training
1433
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1449,8 +1620,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1449
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
  parser.add_argument("--caption_extention", type=str, default=None,
1451
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
- parser.add_argument("--keep_tokens", type=int, default=None,
1453
- help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1475,11 +1646,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1475
  if support_caption_dropout:
1476
  # Textual Inversion はcaptionのdropoutをsupportしない
1477
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1478
- parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
- parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
- parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
 
1485
  if support_dreambooth:
@@ -1504,16 +1675,256 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1504
  # region utils
1505
 
1506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1507
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
  # backward compatibility
1509
  if args.caption_extention is not None:
1510
  args.caption_extension = args.caption_extention
1511
  args.caption_extention = None
1512
 
1513
- if args.cache_latents:
1514
- assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
- assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
-
1517
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
  if args.resolution is not None:
1519
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
@@ -1536,12 +1947,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1536
 
1537
  def load_tokenizer(args: argparse.Namespace):
1538
  print("prepare tokenizer")
1539
- if args.v2:
1540
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
- else:
1542
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
- if args.max_token_length is not None:
 
 
 
 
 
 
 
 
 
 
 
1544
  print(f"update token length: {args.max_token_length}")
 
 
 
 
 
1545
  return tokenizer
1546
 
1547
 
@@ -1592,13 +2019,19 @@ def prepare_dtype(args: argparse.Namespace):
1592
 
1593
 
1594
  def load_target_model(args: argparse.Namespace, weight_dtype):
1595
- load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
 
 
1596
  if load_stable_diffusion_format:
1597
  print("load StableDiffusion checkpoint")
1598
- text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
  else:
1600
  print("load Diffusers pretrained models")
1601
- pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
 
 
 
 
1602
  text_encoder = pipe.text_encoder
1603
  vae = pipe.vae
1604
  unet = pipe.unet
@@ -1767,6 +2200,197 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1767
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1770
  # endregion
1771
 
1772
  # region 前処理用
 
1
  # common functions for training
2
 
3
  import argparse
4
+ import importlib
5
  import json
6
+ import re
7
  import shutil
8
  import time
9
+ from typing import (
10
+ Dict,
11
+ List,
12
+ NamedTuple,
13
+ Optional,
14
+ Sequence,
15
+ Tuple,
16
+ Union,
17
+ )
18
  from accelerate import Accelerator
 
19
  import glob
20
  import math
21
  import os
 
26
 
27
  from tqdm import tqdm
28
  import torch
29
+ from torch.optim import Optimizer
30
  from torchvision import transforms
31
  from transformers import CLIPTokenizer
32
+ import transformers
33
  import diffusers
34
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+ from diffusers import (StableDiffusionPipeline, DDPMScheduler,
36
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
37
+ LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
38
+ KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
39
  import albumentations as albu
40
  import numpy as np
41
  from PIL import Image
 
210
  batch_index: int
211
 
212
 
213
+ class AugHelper:
214
+ def __init__(self):
215
+ # prepare all possible augmentators
216
+ color_aug_method = albu.OneOf([
217
+ albu.HueSaturationValue(8, 0, 0, p=.5),
218
+ albu.RandomGamma((95, 105), p=.5),
219
+ ], p=.33)
220
+ flip_aug_method = albu.HorizontalFlip(p=0.5)
221
+
222
+ # key: (use_color_aug, use_flip_aug)
223
+ self.augmentors = {
224
+ (True, True): albu.Compose([
225
+ color_aug_method,
226
+ flip_aug_method,
227
+ ], p=1.),
228
+ (True, False): albu.Compose([
229
+ color_aug_method,
230
+ ], p=1.),
231
+ (False, True): albu.Compose([
232
+ flip_aug_method,
233
+ ], p=1.),
234
+ (False, False): None
235
+ }
236
+
237
+ def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
238
+ return self.augmentors[(use_color_aug, use_flip_aug)]
239
+
240
+
241
+ class BaseSubset:
242
+ def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
243
+ self.image_dir = image_dir
244
+ self.num_repeats = num_repeats
245
+ self.shuffle_caption = shuffle_caption
246
+ self.keep_tokens = keep_tokens
247
+ self.color_aug = color_aug
248
+ self.flip_aug = flip_aug
249
+ self.face_crop_aug_range = face_crop_aug_range
250
+ self.random_crop = random_crop
251
+ self.caption_dropout_rate = caption_dropout_rate
252
+ self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
253
+ self.caption_tag_dropout_rate = caption_tag_dropout_rate
254
+
255
+ self.img_count = 0
256
+
257
+
258
+ class DreamBoothSubset(BaseSubset):
259
+ def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
260
+ assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
261
+
262
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
263
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
264
+
265
+ self.is_reg = is_reg
266
+ self.class_tokens = class_tokens
267
+ self.caption_extension = caption_extension
268
+
269
+ def __eq__(self, other) -> bool:
270
+ if not isinstance(other, DreamBoothSubset):
271
+ return NotImplemented
272
+ return self.image_dir == other.image_dir
273
+
274
+
275
+ class FineTuningSubset(BaseSubset):
276
+ def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
277
+ assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
278
+
279
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
280
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
281
+
282
+ self.metadata_file = metadata_file
283
+
284
+ def __eq__(self, other) -> bool:
285
+ if not isinstance(other, FineTuningSubset):
286
+ return NotImplemented
287
+ return self.metadata_file == other.metadata_file
288
+
289
+
290
  class BaseDataset(torch.utils.data.Dataset):
291
+ def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
292
  super().__init__()
293
+ self.tokenizer = tokenizer
294
  self.max_token_length = max_token_length
 
 
295
  # width/height is used when enable_bucket==False
296
  self.width, self.height = (None, None) if resolution is None else resolution
 
 
 
297
  self.debug_dataset = debug_dataset
298
+
299
+ self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
300
+
301
  self.token_padding_disabled = False
 
 
302
  self.tag_frequency = {}
303
 
304
  self.enable_bucket = False
 
312
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
313
 
314
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
 
 
315
 
316
  # augmentation
317
+ self.aug_helper = AugHelper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
320
 
321
  self.image_data: Dict[str, ImageInfo] = {}
322
+ self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
323
 
324
  self.replacements = {}
325
 
326
  def set_current_epoch(self, epoch):
327
  self.current_epoch = epoch
328
+ self.shuffle_buckets()
 
 
 
 
 
329
 
330
  def set_tag_frequency(self, dir_name, captions):
331
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
332
  self.tag_frequency[dir_name] = frequency_for_dir
333
  for caption in captions:
334
  for tag in caption.split(","):
335
+ tag = tag.strip()
336
+ if tag:
337
  tag = tag.lower()
338
  frequency = frequency_for_dir.get(tag, 0)
339
  frequency_for_dir[tag] = frequency + 1
 
344
  def add_replacement(self, str_from, str_to):
345
  self.replacements[str_from] = str_to
346
 
347
+ def process_caption(self, subset: BaseSubset, caption):
348
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
349
+ is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
350
+ is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
351
 
352
  if is_drop_out:
353
  caption = ""
354
  else:
355
+ if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
356
  def dropout_tags(tokens):
357
+ if subset.caption_tag_dropout_rate <= 0:
358
  return tokens
359
  l = []
360
  for token in tokens:
361
+ if random.random() >= subset.caption_tag_dropout_rate:
362
  l.append(token)
363
  return l
364
 
365
+ fixed_tokens = []
366
+ flex_tokens = [t.strip() for t in caption.strip().split(",")]
367
+ if subset.keep_tokens > 0:
368
+ fixed_tokens = flex_tokens[:subset.keep_tokens]
369
+ flex_tokens = flex_tokens[subset.keep_tokens:]
 
 
 
 
 
370
 
371
+ if subset.shuffle_caption:
372
+ random.shuffle(flex_tokens)
373
 
374
+ flex_tokens = dropout_tags(flex_tokens)
375
 
376
+ caption = ", ".join(fixed_tokens + flex_tokens)
 
377
 
378
  # textual inversion対応
379
  for str_from, str_to in self.replacements.items():
 
427
  input_ids = torch.stack(iids_list) # 3,77
428
  return input_ids
429
 
430
+ def register_image(self, info: ImageInfo, subset: BaseSubset):
431
  self.image_data[info.image_key] = info
432
+ self.image_to_subset[info.image_key] = subset
433
 
434
  def make_buckets(self):
435
  '''
 
528
  img = np.array(image, np.uint8)
529
  return img
530
 
531
+ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
532
  image_height, image_width = image.shape[0:2]
533
 
534
  if image_width != resized_size[0] or image_height != resized_size[1]:
 
538
  image_height, image_width = image.shape[0:2]
539
  if image_width > reso[0]:
540
  trim_size = image_width - reso[0]
541
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
542
  # print("w", trim_size, p)
543
  image = image[:, p:p + reso[0]]
544
  if image_height > reso[1]:
545
  trim_size = image_height - reso[1]
546
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
547
  # print("h", trim_size, p)
548
  image = image[p:p + reso[1]]
549
 
550
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
551
  return image
552
 
553
+ def is_latent_cacheable(self):
554
+ return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
555
+
556
  def cache_latents(self, vae):
557
  # TODO ここを高速化したい
558
  print("caching latents.")
559
  for info in tqdm(self.image_data.values()):
560
+ subset = self.image_to_subset[info.image_key]
561
+
562
  if info.latents_npz is not None:
563
  info.latents = self.load_latents_from_npz(info, False)
564
  info.latents = torch.FloatTensor(info.latents)
 
568
  continue
569
 
570
  image = self.load_image(info.absolute_path)
571
+ image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
572
 
573
  img_tensor = self.image_transforms(image)
574
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
575
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
576
 
577
+ if subset.flip_aug:
578
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
579
  img_tensor = self.image_transforms(image)
580
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
 
584
  image = Image.open(image_path)
585
  return image.size
586
 
587
+ def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
588
  img = self.load_image(image_path)
589
 
590
  face_cx = face_cy = face_w = face_h = 0
591
+ if subset.face_crop_aug_range is not None:
592
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
593
  if len(tokens) >= 5:
594
  face_cx = int(tokens[-4])
 
599
  return img, face_cx, face_cy, face_w, face_h
600
 
601
  # いい感じに切り出す
602
+ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
603
  height, width = image.shape[0:2]
604
  if height == self.height and width == self.width:
605
  return image
 
607
  # 画像サイズはsizeより大きいのでリサイズする
608
  face_size = max(face_w, face_h)
609
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
610
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
611
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
612
  if min_scale >= max_scale: # range指定がmin==max
613
  scale = min_scale
614
  else:
 
626
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
627
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
628
 
629
+ if subset.random_crop:
630
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
631
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
632
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
633
  else:
634
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
635
+ if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
636
  if face_size > self.size // 10 and face_size >= 40:
637
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
638
 
 
655
  return self._length
656
 
657
  def __getitem__(self, index):
 
 
 
658
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
659
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
660
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
 
667
 
668
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
669
  image_info = self.image_data[image_key]
670
+ subset = self.image_to_subset[image_key]
671
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
672
 
673
  # image/latentsを処理する
674
  if image_info.latents is not None:
675
+ latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
676
  image = None
677
  elif image_info.latents_npz is not None:
678
+ latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
679
  latents = torch.FloatTensor(latents)
680
  image = None
681
  else:
682
  # 画像を読み込み、必要ならcropする
683
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
684
  im_h, im_w = img.shape[0:2]
685
 
686
  if self.enable_bucket:
687
+ img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
688
  else:
689
  if face_cx > 0: # 顔位置情報あり
690
+ img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
691
  elif im_h > self.height or im_w > self.width:
692
+ assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
693
  if im_h > self.height:
694
  p = random.randint(0, im_h - self.height)
695
  img = img[p:p + self.height]
 
701
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
702
 
703
  # augmentation
704
+ aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
705
+ if aug is not None:
706
+ img = aug(image=img)['image']
707
 
708
  latents = None
709
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
 
711
  images.append(image)
712
  latents_list.append(latents)
713
 
714
+ caption = self.process_caption(subset, image_info.caption)
715
  captions.append(caption)
716
  if not self.token_padding_disabled: # this option might be omitted in future
717
  input_ids_list.append(self.get_input_ids(caption))
 
742
 
743
 
744
  class DreamBoothDataset(BaseDataset):
745
+ def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
746
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
747
 
748
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
749
 
 
766
  self.bucket_reso_steps = None # この情報は使われない
767
  self.bucket_no_upscale = False
768
 
769
+ def read_caption(img_path, caption_extension):
770
  # captionの候補ファイル名を作る
771
  base_name = os.path.splitext(img_path)[0]
772
  base_name_face_det = base_name
 
789
  break
790
  return caption
791
 
792
+ def load_dreambooth_dir(subset: DreamBoothSubset):
793
+ if not os.path.isdir(subset.image_dir):
794
+ print(f"not directory: {subset.image_dir}")
795
+ return [], []
796
 
797
+ img_paths = glob_images(subset.image_dir, "*")
798
+ print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
 
 
 
 
 
 
 
799
 
800
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
801
  captions = []
802
  for img_path in img_paths:
803
+ cap_for_img = read_caption(img_path, subset.caption_extension)
804
+ if cap_for_img is None and subset.class_tokens is None:
805
+ print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
806
+ captions.append("")
807
+ else:
808
+ captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
809
 
810
+ self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
811
 
812
+ return img_paths, captions
813
 
814
+ print("prepare images.")
 
815
  num_train_images = 0
816
+ num_reg_images = 0
817
+ reg_infos: List[ImageInfo] = []
818
+ for subset in subsets:
819
+ if subset.num_repeats < 1:
820
+ print(
821
+ f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
822
+ continue
823
+
824
+ if subset in self.subsets:
825
+ print(
826
+ f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
827
+ continue
828
+
829
+ img_paths, captions = load_dreambooth_dir(subset)
830
+ if len(img_paths) < 1:
831
+ print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
832
+ continue
833
+
834
+ if subset.is_reg:
835
+ num_reg_images += subset.num_repeats * len(img_paths)
836
+ else:
837
+ num_train_images += subset.num_repeats * len(img_paths)
838
 
839
  for img_path, caption in zip(img_paths, captions):
840
+ info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
841
+ if subset.is_reg:
842
+ reg_infos.append(info)
843
+ else:
844
+ self.register_image(info, subset)
845
 
846
+ subset.img_count = len(img_paths)
847
+ self.subsets.append(subset)
848
 
849
  print(f"{num_train_images} train images with repeating.")
850
  self.num_train_images = num_train_images
851
 
852
+ print(f"{num_reg_images} reg images.")
853
+ if num_train_images < num_reg_images:
854
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
 
 
855
 
856
+ if num_reg_images == 0:
857
+ print("no regularization images / 正則化画像が見つかりませんでした")
858
+ else:
859
+ # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
860
+ n = 0
861
+ first_loop = True
862
+ while n < num_train_images:
863
+ for info in reg_infos:
864
+ if first_loop:
865
+ self.register_image(info, subset)
866
+ n += info.num_repeats
867
+ else:
868
+ info.num_repeats += 1
869
+ n += 1
870
+ if n >= num_train_images:
871
+ break
872
+ first_loop = False
873
 
874
+ self.num_reg_images = num_reg_images
 
 
875
 
 
876
 
877
+ class FineTuningDataset(BaseDataset):
878
+ def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
879
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
880
+
881
+ self.batch_size = batch_size
882
+
883
+ self.num_train_images = 0
884
+ self.num_reg_images = 0
885
 
886
+ for subset in subsets:
887
+ if subset.num_repeats < 1:
888
+ print(
889
+ f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
890
+ continue
891
+
892
+ if subset in self.subsets:
893
+ print(
894
+ f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
895
+ continue
896
+
897
+ # メタデータを読み込む
898
+ if os.path.exists(subset.metadata_file):
899
+ print(f"loading existing metadata: {subset.metadata_file}")
900
+ with open(subset.metadata_file, "rt", encoding='utf-8') as f:
901
+ metadata = json.load(f)
902
  else:
903
+ raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
904
 
905
+ if len(metadata) < 1:
906
+ print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
907
+ continue
908
 
909
+ tags_list = []
910
+ for image_key, img_md in metadata.items():
911
+ # path情報を作る
912
+ if os.path.exists(image_key):
913
+ abs_path = image_key
914
+ else:
915
+ npz_path = os.path.join(subset.image_dir, image_key + ".npz")
916
+ if os.path.exists(npz_path):
917
+ abs_path = npz_path
918
+ else:
919
+ # わりといい加減だがいい方法が思いつかん
920
+ abs_path = glob_images(subset.image_dir, image_key)
921
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
922
+ abs_path = abs_path[0]
923
 
924
+ caption = img_md.get('caption')
925
+ tags = img_md.get('tags')
926
+ if caption is None:
927
+ caption = tags
928
+ elif tags is not None and len(tags) > 0:
929
+ caption = caption + ', ' + tags
930
+ tags_list.append(tags)
 
 
 
 
 
931
 
932
+ if caption is None:
933
+ caption = ""
 
934
 
935
+ image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
936
+ image_info.image_size = img_md.get('train_resolution')
937
+
938
+ if not subset.color_aug and not subset.random_crop:
939
+ # if npz exists, use them
940
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
941
+
942
+ self.register_image(image_info, subset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943
 
944
+ self.num_train_images += len(metadata) * subset.num_repeats
945
+
946
+ # TODO do not record tag freq when no tag
947
+ self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
948
+ subset.img_count = len(metadata)
949
+ self.subsets.append(subset)
950
 
951
  # check existence of all npz files
952
+ use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
953
  if use_npz_latents:
954
+ flip_aug_in_subset = False
955
  npz_any = False
956
  npz_all = True
957
+
958
  for image_info in self.image_data.values():
959
+ subset = self.image_to_subset[image_info.image_key]
960
+
961
  has_npz = image_info.latents_npz is not None
962
  npz_any = npz_any or has_npz
963
 
964
+ if subset.flip_aug:
965
  has_npz = has_npz and image_info.latents_npz_flipped is not None
966
+ flip_aug_in_subset = True
967
  npz_all = npz_all and has_npz
968
 
969
  if npz_any and not npz_all:
 
975
  elif not npz_all:
976
  use_npz_latents = False
977
  print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
978
+ if flip_aug_in_subset:
979
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
980
  # else:
981
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
 
1021
  for image_info in self.image_data.values():
1022
  image_info.latents_npz = image_info.latents_npz_flipped = None
1023
 
1024
+ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
1025
  base_name = os.path.splitext(image_key)[0]
1026
  npz_file_norm = base_name + '.npz'
1027
 
 
1033
  return npz_file_norm, npz_file_flip
1034
 
1035
  # image_key is relative path
1036
+ npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
1037
+ npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
1038
 
1039
  if not os.path.exists(npz_file_norm):
1040
  npz_file_norm = None
 
1045
  return npz_file_norm, npz_file_flip
1046
 
1047
 
1048
+ # behave as Dataset mock
1049
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1050
+ def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
1051
+ self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
1052
+
1053
+ super().__init__(datasets)
1054
+
1055
+ self.image_data = {}
1056
+ self.num_train_images = 0
1057
+ self.num_reg_images = 0
1058
+
1059
+ # simply concat together
1060
+ # TODO: handling image_data key duplication among dataset
1061
+ # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
1062
+ for dataset in datasets:
1063
+ self.image_data.update(dataset.image_data)
1064
+ self.num_train_images += dataset.num_train_images
1065
+ self.num_reg_images += dataset.num_reg_images
1066
+
1067
+ def add_replacement(self, str_from, str_to):
1068
+ for dataset in self.datasets:
1069
+ dataset.add_replacement(str_from, str_to)
1070
+
1071
+ # def make_buckets(self):
1072
+ # for dataset in self.datasets:
1073
+ # dataset.make_buckets()
1074
+
1075
+ def cache_latents(self, vae):
1076
+ for i, dataset in enumerate(self.datasets):
1077
+ print(f"[Dataset {i}]")
1078
+ dataset.cache_latents(vae)
1079
+
1080
+ def is_latent_cacheable(self) -> bool:
1081
+ return all([dataset.is_latent_cacheable() for dataset in self.datasets])
1082
+
1083
+ def set_current_epoch(self, epoch):
1084
+ for dataset in self.datasets:
1085
+ dataset.set_current_epoch(epoch)
1086
+
1087
+ def disable_token_padding(self):
1088
+ for dataset in self.datasets:
1089
+ dataset.disable_token_padding()
1090
+
1091
+
1092
  def debug_dataset(train_dataset, show_input_ids=False):
1093
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
1094
  print("Escape for exit. / Escキーで中断、終了します")
1095
 
1096
  train_dataset.set_current_epoch(1)
1097
  k = 0
1098
+ indices = list(range(len(train_dataset)))
1099
+ random.shuffle(indices)
1100
+ for i, idx in enumerate(indices):
1101
+ example = train_dataset[idx]
1102
  if example['latents'] is not None:
1103
  print(f"sample has latents from npz file: {example['latents'].size()}")
1104
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
 
1503
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1504
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1505
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1506
+ parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
1507
+ help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
1508
+
1509
+
1510
+ def add_optimizer_arguments(parser: argparse.ArgumentParser):
1511
+ parser.add_argument("--optimizer_type", type=str, default="",
1512
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
1513
+
1514
+ # backward compatibility
1515
+ parser.add_argument("--use_8bit_adam", action="store_true",
1516
+ help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1517
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1518
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1519
+
1520
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1521
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
1522
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
1523
+
1524
+ parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
1525
+ help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
1526
+
1527
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1528
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
1529
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1530
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1531
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
1532
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
1533
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
1534
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
1535
 
1536
 
1537
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
 
1555
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1556
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1557
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
 
 
 
 
1558
  parser.add_argument("--mem_eff_attn", action="store_true",
1559
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1560
  parser.add_argument("--xformers", action="store_true",
 
1562
  parser.add_argument("--vae", type=str, default=None,
1563
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1564
 
 
1565
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1566
  parser.add_argument("--max_train_epochs", type=int, default=None,
1567
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
 
1582
  parser.add_argument("--logging_dir", type=str, default=None,
1583
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1584
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
 
 
 
 
1585
  parser.add_argument("--noise_offset", type=float, default=None,
1586
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1587
  parser.add_argument("--lowram", action="store_true",
1588
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1589
 
1590
+ parser.add_argument("--sample_every_n_steps", type=int, default=None,
1591
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
1592
+ parser.add_argument("--sample_every_n_epochs", type=int, default=None,
1593
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
1594
+ parser.add_argument("--sample_prompts", type=str, default=None,
1595
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
1596
+ parser.add_argument('--sample_sampler', type=str, default='ddim',
1597
+ choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
1598
+ 'dpmsolver++', 'dpmsingle',
1599
+ 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
1600
+ help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
1601
+
1602
  if support_dreambooth:
1603
  # DreamBooth training
1604
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
 
1620
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1621
  parser.add_argument("--caption_extention", type=str, default=None,
1622
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1623
+ parser.add_argument("--keep_tokens", type=int, default=0,
1624
+ help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
1625
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1626
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1627
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
 
1646
  if support_caption_dropout:
1647
  # Textual Inversion はcaptionのdropoutをsupportしない
1648
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1649
+ parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
1650
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1651
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
1652
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1653
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
1654
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1655
 
1656
  if support_dreambooth:
 
1675
  # region utils
1676
 
1677
 
1678
+ def get_optimizer(args, trainable_params):
1679
+ # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
1680
+
1681
+ optimizer_type = args.optimizer_type
1682
+ if args.use_8bit_adam:
1683
+ assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
1684
+ assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
1685
+ optimizer_type = "AdamW8bit"
1686
+
1687
+ elif args.use_lion_optimizer:
1688
+ assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
1689
+ optimizer_type = "Lion"
1690
+
1691
+ if optimizer_type is None or optimizer_type == "":
1692
+ optimizer_type = "AdamW"
1693
+ optimizer_type = optimizer_type.lower()
1694
+
1695
+ # 引数を分解する:boolとfloat、tupleのみ対応
1696
+ optimizer_kwargs = {}
1697
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
1698
+ for arg in args.optimizer_args:
1699
+ key, value = arg.split('=')
1700
+
1701
+ value = value.split(",")
1702
+ for i in range(len(value)):
1703
+ if value[i].lower() == "true" or value[i].lower() == "false":
1704
+ value[i] = (value[i].lower() == "true")
1705
+ else:
1706
+ value[i] = float(value[i])
1707
+ if len(value) == 1:
1708
+ value = value[0]
1709
+ else:
1710
+ value = tuple(value)
1711
+
1712
+ optimizer_kwargs[key] = value
1713
+ # print("optkwargs:", optimizer_kwargs)
1714
+
1715
+ lr = args.learning_rate
1716
+
1717
+ if optimizer_type == "AdamW8bit".lower():
1718
+ try:
1719
+ import bitsandbytes as bnb
1720
+ except ImportError:
1721
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1722
+ print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
1723
+ optimizer_class = bnb.optim.AdamW8bit
1724
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1725
+
1726
+ elif optimizer_type == "SGDNesterov8bit".lower():
1727
+ try:
1728
+ import bitsandbytes as bnb
1729
+ except ImportError:
1730
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1731
+ print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
1732
+ if "momentum" not in optimizer_kwargs:
1733
+ print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1734
+ optimizer_kwargs["momentum"] = 0.9
1735
+
1736
+ optimizer_class = bnb.optim.SGD8bit
1737
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1738
+
1739
+ elif optimizer_type == "Lion".lower():
1740
+ try:
1741
+ import lion_pytorch
1742
+ except ImportError:
1743
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
1744
+ print(f"use Lion optimizer | {optimizer_kwargs}")
1745
+ optimizer_class = lion_pytorch.Lion
1746
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1747
+
1748
+ elif optimizer_type == "SGDNesterov".lower():
1749
+ print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
1750
+ if "momentum" not in optimizer_kwargs:
1751
+ print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1752
+ optimizer_kwargs["momentum"] = 0.9
1753
+
1754
+ optimizer_class = torch.optim.SGD
1755
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1756
+
1757
+ elif optimizer_type == "DAdaptation".lower():
1758
+ try:
1759
+ import dadaptation
1760
+ except ImportError:
1761
+ raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
1762
+ print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
1763
+
1764
+ actual_lr = lr
1765
+ lr_count = 1
1766
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1767
+ lrs = set()
1768
+ actual_lr = trainable_params[0].get("lr", actual_lr)
1769
+ for group in trainable_params:
1770
+ lrs.add(group.get("lr", actual_lr))
1771
+ lr_count = len(lrs)
1772
+
1773
+ if actual_lr <= 0.1:
1774
+ print(
1775
+ f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
1776
+ print('recommend option: lr=1.0 / 推奨は1.0です')
1777
+ if lr_count > 1:
1778
+ print(
1779
+ f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
1780
+
1781
+ optimizer_class = dadaptation.DAdaptAdam
1782
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1783
+
1784
+ elif optimizer_type == "Adafactor".lower():
1785
+ # 引数を確認して適宜補正する
1786
+ if "relative_step" not in optimizer_kwargs:
1787
+ optimizer_kwargs["relative_step"] = True # default
1788
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
1789
+ print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
1790
+ optimizer_kwargs["relative_step"] = True
1791
+ print(f"use Adafactor optimizer | {optimizer_kwargs}")
1792
+
1793
+ if optimizer_kwargs["relative_step"]:
1794
+ print(f"relative_step is true / relative_stepがtrueです")
1795
+ if lr != 0.0:
1796
+ print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
1797
+ args.learning_rate = None
1798
+
1799
+ # trainable_paramsがgroupだった時の処理:lrを削除する
1800
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1801
+ has_group_lr = False
1802
+ for group in trainable_params:
1803
+ p = group.pop("lr", None)
1804
+ has_group_lr = has_group_lr or (p is not None)
1805
+
1806
+ if has_group_lr:
1807
+ # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
1808
+ print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
1809
+ args.unet_lr = None
1810
+ args.text_encoder_lr = None
1811
+
1812
+ if args.lr_scheduler != "adafactor":
1813
+ print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
1814
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
1815
+
1816
+ lr = None
1817
+ else:
1818
+ if args.max_grad_norm != 0.0:
1819
+ print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
1820
+ if args.lr_scheduler != "constant_with_warmup":
1821
+ print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
1822
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
1823
+ print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
1824
+
1825
+ optimizer_class = transformers.optimization.Adafactor
1826
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1827
+
1828
+ elif optimizer_type == "AdamW".lower():
1829
+ print(f"use AdamW optimizer | {optimizer_kwargs}")
1830
+ optimizer_class = torch.optim.AdamW
1831
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1832
+
1833
+ else:
1834
+ # 任意のoptimizerを使う
1835
+ optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
1836
+ print(f"use {optimizer_type} | {optimizer_kwargs}")
1837
+ if "." not in optimizer_type:
1838
+ optimizer_module = torch.optim
1839
+ else:
1840
+ values = optimizer_type.split(".")
1841
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
1842
+ optimizer_type = values[-1]
1843
+
1844
+ optimizer_class = getattr(optimizer_module, optimizer_type)
1845
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1846
+
1847
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
1848
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
1849
+
1850
+ return optimizer_name, optimizer_args, optimizer
1851
+
1852
+
1853
+ # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
1854
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
1855
+ # Which is a newer release of diffusers than currently packaged with sd-scripts
1856
+ # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
1857
+
1858
+
1859
+ def get_scheduler_fix(
1860
+ name: Union[str, SchedulerType],
1861
+ optimizer: Optimizer,
1862
+ num_warmup_steps: Optional[int] = None,
1863
+ num_training_steps: Optional[int] = None,
1864
+ num_cycles: int = 1,
1865
+ power: float = 1.0,
1866
+ ):
1867
+ """
1868
+ Unified API to get any scheduler from its name.
1869
+ Args:
1870
+ name (`str` or `SchedulerType`):
1871
+ The name of the scheduler to use.
1872
+ optimizer (`torch.optim.Optimizer`):
1873
+ The optimizer that will be used during training.
1874
+ num_warmup_steps (`int`, *optional*):
1875
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
1876
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1877
+ num_training_steps (`int``, *optional*):
1878
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
1879
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1880
+ num_cycles (`int`, *optional*):
1881
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
1882
+ power (`float`, *optional*, defaults to 1.0):
1883
+ Power factor. See `POLYNOMIAL` scheduler
1884
+ last_epoch (`int`, *optional*, defaults to -1):
1885
+ The index of the last epoch when resuming training.
1886
+ """
1887
+ if name.startswith("adafactor"):
1888
+ assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
1889
+ initial_lr = float(name.split(':')[1])
1890
+ # print("adafactor scheduler init lr", initial_lr)
1891
+ return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
1892
+
1893
+ name = SchedulerType(name)
1894
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
1895
+ if name == SchedulerType.CONSTANT:
1896
+ return schedule_func(optimizer)
1897
+
1898
+ # All other schedulers require `num_warmup_steps`
1899
+ if num_warmup_steps is None:
1900
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
1901
+
1902
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
1903
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
1904
+
1905
+ # All other schedulers require `num_training_steps`
1906
+ if num_training_steps is None:
1907
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1908
+
1909
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
1910
+ return schedule_func(
1911
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
1912
+ )
1913
+
1914
+ if name == SchedulerType.POLYNOMIAL:
1915
+ return schedule_func(
1916
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
1917
+ )
1918
+
1919
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
1920
+
1921
+
1922
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1923
  # backward compatibility
1924
  if args.caption_extention is not None:
1925
  args.caption_extension = args.caption_extention
1926
  args.caption_extention = None
1927
 
 
 
 
 
1928
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1929
  if args.resolution is not None:
1930
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
 
1947
 
1948
  def load_tokenizer(args: argparse.Namespace):
1949
  print("prepare tokenizer")
1950
+ original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
1951
+
1952
+ tokenizer: CLIPTokenizer = None
1953
+ if args.tokenizer_cache_dir:
1954
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
1955
+ if os.path.exists(local_tokenizer_path):
1956
+ print(f"load tokenizer from cache: {local_tokenizer_path}")
1957
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
1958
+
1959
+ if tokenizer is None:
1960
+ if args.v2:
1961
+ tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
1962
+ else:
1963
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
1964
+
1965
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
1966
  print(f"update token length: {args.max_token_length}")
1967
+
1968
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
1969
+ print(f"save Tokenizer to cache: {local_tokenizer_path}")
1970
+ tokenizer.save_pretrained(local_tokenizer_path)
1971
+
1972
  return tokenizer
1973
 
1974
 
 
2019
 
2020
 
2021
  def load_target_model(args: argparse.Namespace, weight_dtype):
2022
+ name_or_path = args.pretrained_model_name_or_path
2023
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
2024
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
2025
  if load_stable_diffusion_format:
2026
  print("load StableDiffusion checkpoint")
2027
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
2028
  else:
2029
  print("load Diffusers pretrained models")
2030
+ try:
2031
+ pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
2032
+ except EnvironmentError as ex:
2033
+ print(
2034
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
2035
  text_encoder = pipe.text_encoder
2036
  vae = pipe.vae
2037
  unet = pipe.unet
 
2200
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
2201
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
2202
 
2203
+
2204
+ # scheduler:
2205
+ SCHEDULER_LINEAR_START = 0.00085
2206
+ SCHEDULER_LINEAR_END = 0.0120
2207
+ SCHEDULER_TIMESTEPS = 1000
2208
+ SCHEDLER_SCHEDULE = 'scaled_linear'
2209
+
2210
+
2211
+ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
2212
+ """
2213
+ 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
2214
+ clip skipは対応した
2215
+ """
2216
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
2217
+ return
2218
+ if args.sample_every_n_epochs is not None:
2219
+ # sample_every_n_steps は無視する
2220
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
2221
+ return
2222
+ else:
2223
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
2224
+ return
2225
+
2226
+ print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
2227
+ if not os.path.isfile(args.sample_prompts):
2228
+ print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
2229
+ return
2230
+
2231
+ org_vae_device = vae.device # CPUにいるはず
2232
+ vae.to(device)
2233
+
2234
+ # clip skip 対応のための wrapper を作る
2235
+ if args.clip_skip is None:
2236
+ text_encoder_or_wrapper = text_encoder
2237
+ else:
2238
+ class Wrapper():
2239
+ def __init__(self, tenc) -> None:
2240
+ self.tenc = tenc
2241
+ self.config = {}
2242
+ super().__init__()
2243
+
2244
+ def __call__(self, input_ids, attention_mask):
2245
+ enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
2246
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
2247
+ encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
2248
+ pooled_output = enc_out['pooler_output']
2249
+ return encoder_hidden_states, pooled_output # 1st output is only used
2250
+
2251
+ text_encoder_or_wrapper = Wrapper(text_encoder)
2252
+
2253
+ # read prompts
2254
+ with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
2255
+ prompts = f.readlines()
2256
+
2257
+ # schedulerを用意する
2258
+ sched_init_args = {}
2259
+ if args.sample_sampler == "ddim":
2260
+ scheduler_cls = DDIMScheduler
2261
+ elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
2262
+ scheduler_cls = DDPMScheduler
2263
+ elif args.sample_sampler == "pndm":
2264
+ scheduler_cls = PNDMScheduler
2265
+ elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
2266
+ scheduler_cls = LMSDiscreteScheduler
2267
+ elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
2268
+ scheduler_cls = EulerDiscreteScheduler
2269
+ elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
2270
+ scheduler_cls = EulerAncestralDiscreteScheduler
2271
+ elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
2272
+ scheduler_cls = DPMSolverMultistepScheduler
2273
+ sched_init_args['algorithm_type'] = args.sample_sampler
2274
+ elif args.sample_sampler == "dpmsingle":
2275
+ scheduler_cls = DPMSolverSinglestepScheduler
2276
+ elif args.sample_sampler == "heun":
2277
+ scheduler_cls = HeunDiscreteScheduler
2278
+ elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
2279
+ scheduler_cls = KDPM2DiscreteScheduler
2280
+ elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
2281
+ scheduler_cls = KDPM2AncestralDiscreteScheduler
2282
+ else:
2283
+ scheduler_cls = DDIMScheduler
2284
+
2285
+ if args.v_parameterization:
2286
+ sched_init_args['prediction_type'] = 'v_prediction'
2287
+
2288
+ scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
2289
+ beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
2290
+ beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
2291
+
2292
+ # clip_sample=Trueにする
2293
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
2294
+ # print("set clip_sample to True")
2295
+ scheduler.config.clip_sample = True
2296
+
2297
+ pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
2298
+ scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
2299
+ pipeline.to(device)
2300
+
2301
+ save_dir = args.output_dir + "/sample"
2302
+ os.makedirs(save_dir, exist_ok=True)
2303
+
2304
+ rng_state = torch.get_rng_state()
2305
+ cuda_rng_state = torch.cuda.get_rng_state()
2306
+
2307
+ with torch.no_grad():
2308
+ with accelerator.autocast():
2309
+ for i, prompt in enumerate(prompts):
2310
+ if not accelerator.is_main_process:
2311
+ continue
2312
+ prompt = prompt.strip()
2313
+ if len(prompt) == 0 or prompt[0] == '#':
2314
+ continue
2315
+
2316
+ # subset of gen_img_diffusers
2317
+ prompt_args = prompt.split(' --')
2318
+ prompt = prompt_args[0]
2319
+ negative_prompt = None
2320
+ sample_steps = 30
2321
+ width = height = 512
2322
+ scale = 7.5
2323
+ seed = None
2324
+ for parg in prompt_args:
2325
+ try:
2326
+ m = re.match(r'w (\d+)', parg, re.IGNORECASE)
2327
+ if m:
2328
+ width = int(m.group(1))
2329
+ continue
2330
+
2331
+ m = re.match(r'h (\d+)', parg, re.IGNORECASE)
2332
+ if m:
2333
+ height = int(m.group(1))
2334
+ continue
2335
+
2336
+ m = re.match(r'd (\d+)', parg, re.IGNORECASE)
2337
+ if m:
2338
+ seed = int(m.group(1))
2339
+ continue
2340
+
2341
+ m = re.match(r's (\d+)', parg, re.IGNORECASE)
2342
+ if m: # steps
2343
+ sample_steps = max(1, min(1000, int(m.group(1))))
2344
+ continue
2345
+
2346
+ m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
2347
+ if m: # scale
2348
+ scale = float(m.group(1))
2349
+ continue
2350
+
2351
+ m = re.match(r'n (.+)', parg, re.IGNORECASE)
2352
+ if m: # negative prompt
2353
+ negative_prompt = m.group(1)
2354
+ continue
2355
+
2356
+ except ValueError as ex:
2357
+ print(f"Exception in parsing / 解析エラー: {parg}")
2358
+ print(ex)
2359
+
2360
+ if seed is not None:
2361
+ torch.manual_seed(seed)
2362
+ torch.cuda.manual_seed(seed)
2363
+
2364
+ if prompt_replacement is not None:
2365
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
2366
+ if negative_prompt is not None:
2367
+ negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
2368
+
2369
+ height = max(64, height - height % 8) # round to divisible by 8
2370
+ width = max(64, width - width % 8) # round to divisible by 8
2371
+ print(f"prompt: {prompt}")
2372
+ print(f"negative_prompt: {negative_prompt}")
2373
+ print(f"height: {height}")
2374
+ print(f"width: {width}")
2375
+ print(f"sample_steps: {sample_steps}")
2376
+ print(f"scale: {scale}")
2377
+ image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
2378
+
2379
+ ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
2380
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
2381
+ seed_suffix = "" if seed is None else f"_{seed}"
2382
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
2383
+
2384
+ image.save(os.path.join(save_dir, img_filename))
2385
+
2386
+ # clear pipeline and cache to reduce vram usage
2387
+ del pipeline
2388
+ torch.cuda.empty_cache()
2389
+
2390
+ torch.set_rng_state(rng_state)
2391
+ torch.cuda.set_rng_state(cuda_rng_state)
2392
+ vae.to(org_vae_device)
2393
+
2394
  # endregion
2395
 
2396
  # region 前処理用
networks/check_lora_weights.py CHANGED
@@ -21,7 +21,7 @@ def main(file):
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
- print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
 
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
+ print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
networks/extract_lora_from_models.py CHANGED
@@ -45,8 +45,13 @@ def svd(args):
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
- lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
49
- lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
 
 
 
 
 
50
  assert len(lora_network_o.text_encoder_loras) == len(
51
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
52
 
@@ -85,13 +90,28 @@ def svd(args):
85
 
86
  # make LoRA with svd
87
  print("calculating by svd")
88
- rank = args.dim
89
  lora_weights = {}
90
  with torch.no_grad():
91
  for lora_name, mat in tqdm(list(diffs.items())):
 
92
  conv2d = (len(mat.size()) == 4)
 
 
 
 
 
 
 
 
 
 
 
 
93
  if conv2d:
94
- mat = mat.squeeze()
 
 
 
95
 
96
  U, S, Vh = torch.linalg.svd(mat)
97
 
@@ -108,30 +128,27 @@ def svd(args):
108
  U = U.clamp(low_val, hi_val)
109
  Vh = Vh.clamp(low_val, hi_val)
110
 
111
- lora_weights[lora_name] = (U, Vh)
112
-
113
- # make state dict for LoRA
114
- lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
115
- lora_sd = lora_network_o.state_dict()
116
- print(f"LoRA has {len(lora_sd)} weights.")
117
-
118
- for key in list(lora_sd.keys()):
119
- if "alpha" in key:
120
- continue
121
 
122
- lora_name = key.split('.')[0]
123
- i = 0 if "lora_up" in key else 1
124
 
125
- weights = lora_weights[lora_name][i]
126
- # print(key, i, weights.size(), lora_sd[key].size())
127
- if len(lora_sd[key].size()) == 4:
128
- weights = weights.unsqueeze(2).unsqueeze(3)
129
 
130
- assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
131
- lora_sd[key] = weights
 
 
 
 
132
 
133
  # load state dict to LoRA and save it
134
- info = lora_network_o.load_state_dict(lora_sd)
 
 
 
135
  print(f"Loading extracted LoRA weights: {info}")
136
 
137
  dir_name = os.path.dirname(args.save_to)
@@ -139,9 +156,9 @@ def svd(args):
139
  os.makedirs(dir_name, exist_ok=True)
140
 
141
  # minimum metadata
142
- metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
143
 
144
- lora_network_o.save_weights(args.save_to, save_dtype, metadata)
145
  print(f"LoRA weights are saved to: {args.save_to}")
146
 
147
 
@@ -158,6 +175,8 @@ if __name__ == '__main__':
158
  parser.add_argument("--save_to", type=str, default=None,
159
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
160
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
 
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
+ if args.conv_dim is None:
49
+ kwargs = {}
50
+ else:
51
+ kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
52
+
53
+ lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
54
+ lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
55
  assert len(lora_network_o.text_encoder_loras) == len(
56
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
57
 
 
90
 
91
  # make LoRA with svd
92
  print("calculating by svd")
 
93
  lora_weights = {}
94
  with torch.no_grad():
95
  for lora_name, mat in tqdm(list(diffs.items())):
96
+ # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
97
  conv2d = (len(mat.size()) == 4)
98
+ kernel_size = None if not conv2d else mat.size()[2:4]
99
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
100
+
101
+ rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
102
+ out_dim, in_dim = mat.size()[0:2]
103
+
104
+ if args.device:
105
+ mat = mat.to(args.device)
106
+
107
+ # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
108
+ rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
109
+
110
  if conv2d:
111
+ if conv2d_3x3:
112
+ mat = mat.flatten(start_dim=1)
113
+ else:
114
+ mat = mat.squeeze()
115
 
116
  U, S, Vh = torch.linalg.svd(mat)
117
 
 
128
  U = U.clamp(low_val, hi_val)
129
  Vh = Vh.clamp(low_val, hi_val)
130
 
131
+ if conv2d:
132
+ U = U.reshape(out_dim, rank, 1, 1)
133
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
 
 
 
 
 
 
 
134
 
135
+ U = U.to("cpu").contiguous()
136
+ Vh = Vh.to("cpu").contiguous()
137
 
138
+ lora_weights[lora_name] = (U, Vh)
 
 
 
139
 
140
+ # make state dict for LoRA
141
+ lora_sd = {}
142
+ for lora_name, (up_weight, down_weight) in lora_weights.items():
143
+ lora_sd[lora_name + '.lora_up.weight'] = up_weight
144
+ lora_sd[lora_name + '.lora_down.weight'] = down_weight
145
+ lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
146
 
147
  # load state dict to LoRA and save it
148
+ lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
149
+ lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
150
+
151
+ info = lora_network_save.load_state_dict(lora_sd)
152
  print(f"Loading extracted LoRA weights: {info}")
153
 
154
  dir_name = os.path.dirname(args.save_to)
 
156
  os.makedirs(dir_name, exist_ok=True)
157
 
158
  # minimum metadata
159
+ metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
160
 
161
+ lora_network_save.save_weights(args.save_to, save_dtype, metadata)
162
  print(f"LoRA weights are saved to: {args.save_to}")
163
 
164
 
 
175
  parser.add_argument("--save_to", type=str, default=None,
176
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
177
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
178
+ parser.add_argument("--conv_dim", type=int, default=None,
179
+ help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
180
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
181
 
182
  args = parser.parse_args()
networks/lora.py CHANGED
@@ -6,6 +6,7 @@
6
  import math
7
  import os
8
  from typing import List
 
9
  import torch
10
 
11
  from library import train_util
@@ -20,22 +21,34 @@ class LoRAModule(torch.nn.Module):
20
  """ if alpha == 0 or None, alpha is rank (no scaling). """
21
  super().__init__()
22
  self.lora_name = lora_name
23
- self.lora_dim = lora_dim
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
28
- self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
29
- self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
30
  else:
31
  in_dim = org_module.in_features
32
  out_dim = org_module.out_features
33
- self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
34
- self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if type(alpha) == torch.Tensor:
37
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
38
- alpha = lora_dim if alpha is None or alpha == 0 else alpha
39
  self.scale = alpha / self.lora_dim
40
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
41
 
@@ -45,69 +58,192 @@ class LoRAModule(torch.nn.Module):
45
 
46
  self.multiplier = multiplier
47
  self.org_module = org_module # remove in applying
 
 
48
 
49
  def apply_to(self):
50
  self.org_forward = self.org_module.forward
51
  self.org_module.forward = self.forward
52
  del self.org_module
53
 
 
 
 
 
54
  def forward(self, x):
55
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
59
  if network_dim is None:
60
  network_dim = 4 # default
61
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
62
- return network
63
 
 
 
 
 
 
 
 
 
 
64
 
65
- def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
66
- if os.path.splitext(file)[1] == '.safetensors':
67
- from safetensors.torch import load_file, safe_open
68
- weights_sd = load_file(file)
69
- else:
70
- weights_sd = torch.load(file, map_location='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # get dim (rank)
73
- network_alpha = None
74
- network_dim = None
75
- for key, value in weights_sd.items():
76
- if network_alpha is None and 'alpha' in key:
77
- network_alpha = value
78
- if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
79
- network_dim = value.size()[0]
80
 
81
- if network_alpha is None:
82
- network_alpha = network_dim
83
 
84
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  network.weights_sd = weights_sd
86
  return network
87
 
88
 
89
  class LoRANetwork(torch.nn.Module):
 
90
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
 
91
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
  LORA_PREFIX_UNET = 'lora_unet'
93
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
 
95
- def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
96
  super().__init__()
97
  self.multiplier = multiplier
 
98
  self.lora_dim = lora_dim
99
  self.alpha = alpha
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # create module instances
102
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
103
  loras = []
104
  for name, module in root_module.named_modules():
105
  if module.__class__.__name__ in target_replace_modules:
 
106
  for child_name, child_module in module.named_modules():
107
- if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
 
 
 
108
  lora_name = prefix + '.' + name + '.' + child_name
109
  lora_name = lora_name.replace('.', '_')
110
- lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  loras.append(lora)
112
  return loras
113
 
@@ -115,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
115
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
 
118
- self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
 
 
 
 
 
119
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
 
121
  self.weights_sd = None
@@ -126,6 +267,11 @@ class LoRANetwork(torch.nn.Module):
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
 
 
 
 
 
129
  def load_weights(self, file):
130
  if os.path.splitext(file)[1] == '.safetensors':
131
  from safetensors.torch import load_file, safe_open
@@ -235,3 +381,18 @@ class LoRANetwork(torch.nn.Module):
235
  save_file(state_dict, file, metadata)
236
  else:
237
  torch.save(state_dict, file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import math
7
  import os
8
  from typing import List
9
+ import numpy as np
10
  import torch
11
 
12
  from library import train_util
 
21
  """ if alpha == 0 or None, alpha is rank (no scaling). """
22
  super().__init__()
23
  self.lora_name = lora_name
 
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
 
 
28
  else:
29
  in_dim = org_module.in_features
30
  out_dim = org_module.out_features
31
+
32
+ # if limit_rank:
33
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
34
+ # if self.lora_dim != lora_dim:
35
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
36
+ # else:
37
+ self.lora_dim = lora_dim
38
+
39
+ if org_module.__class__.__name__ == 'Conv2d':
40
+ kernel_size = org_module.kernel_size
41
+ stride = org_module.stride
42
+ padding = org_module.padding
43
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
44
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
45
+ else:
46
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
47
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
48
 
49
  if type(alpha) == torch.Tensor:
50
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
51
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
52
  self.scale = alpha / self.lora_dim
53
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
54
 
 
58
 
59
  self.multiplier = multiplier
60
  self.org_module = org_module # remove in applying
61
+ self.region = None
62
+ self.region_mask = None
63
 
64
  def apply_to(self):
65
  self.org_forward = self.org_module.forward
66
  self.org_module.forward = self.forward
67
  del self.org_module
68
 
69
+ def set_region(self, region):
70
+ self.region = region
71
+ self.region_mask = None
72
+
73
  def forward(self, x):
74
+ if self.region is None:
75
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
76
+
77
+ # regional LoRA FIXME same as additional-network extension
78
+ if x.size()[1] % 77 == 0:
79
+ # print(f"LoRA for context: {self.lora_name}")
80
+ self.region = None
81
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
82
+
83
+ # calculate region mask first time
84
+ if self.region_mask is None:
85
+ if len(x.size()) == 4:
86
+ h, w = x.size()[2:4]
87
+ else:
88
+ seq_len = x.size()[1]
89
+ ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
90
+ h = int(self.region.size()[0] / ratio + .5)
91
+ w = seq_len // h
92
+
93
+ r = self.region.to(x.device)
94
+ if r.dtype == torch.bfloat16:
95
+ r = r.to(torch.float)
96
+ r = r.unsqueeze(0).unsqueeze(1)
97
+ # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
98
+ r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
99
+ r = r.to(x.dtype)
100
+
101
+ if len(x.size()) == 3:
102
+ r = torch.reshape(r, (1, x.size()[1], -1))
103
+
104
+ self.region_mask = r
105
+
106
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
107
 
108
 
109
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
110
  if network_dim is None:
111
  network_dim = 4 # default
 
 
112
 
113
+ # extract dim/alpha for conv2d, and block dim
114
+ conv_dim = kwargs.get('conv_dim', None)
115
+ conv_alpha = kwargs.get('conv_alpha', None)
116
+ if conv_dim is not None:
117
+ conv_dim = int(conv_dim)
118
+ if conv_alpha is None:
119
+ conv_alpha = 1.0
120
+ else:
121
+ conv_alpha = float(conv_alpha)
122
 
123
+ """
124
+ block_dims = kwargs.get("block_dims")
125
+ block_alphas = None
126
+
127
+ if block_dims is not None:
128
+ block_dims = [int(d) for d in block_dims.split(',')]
129
+ assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
130
+ block_alphas = kwargs.get("block_alphas")
131
+ if block_alphas is None:
132
+ block_alphas = [1] * len(block_dims)
133
+ else:
134
+ block_alphas = [int(a) for a in block_alphas(',')]
135
+ assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
136
+
137
+ conv_block_dims = kwargs.get("conv_block_dims")
138
+ conv_block_alphas = None
139
+
140
+ if conv_block_dims is not None:
141
+ conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
142
+ assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
143
+ conv_block_alphas = kwargs.get("conv_block_alphas")
144
+ if conv_block_alphas is None:
145
+ conv_block_alphas = [1] * len(conv_block_dims)
146
+ else:
147
+ conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
148
+ assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
149
+ """
150
 
151
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
152
+ alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
153
+ return network
 
 
 
 
 
154
 
 
 
155
 
156
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
157
+ if weights_sd is None:
158
+ if os.path.splitext(file)[1] == '.safetensors':
159
+ from safetensors.torch import load_file, safe_open
160
+ weights_sd = load_file(file)
161
+ else:
162
+ weights_sd = torch.load(file, map_location='cpu')
163
+
164
+ # get dim/alpha mapping
165
+ modules_dim = {}
166
+ modules_alpha = {}
167
+ for key, value in weights_sd.items():
168
+ if '.' not in key:
169
+ continue
170
+
171
+ lora_name = key.split('.')[0]
172
+ if 'alpha' in key:
173
+ modules_alpha[lora_name] = value
174
+ elif 'lora_down' in key:
175
+ dim = value.size()[0]
176
+ modules_dim[lora_name] = dim
177
+ # print(lora_name, value.size(), dim)
178
+
179
+ # support old LoRA without alpha
180
+ for key in modules_dim.keys():
181
+ if key not in modules_alpha:
182
+ modules_alpha = modules_dim[key]
183
+
184
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
185
  network.weights_sd = weights_sd
186
  return network
187
 
188
 
189
  class LoRANetwork(torch.nn.Module):
190
+ # is it possible to apply conv_in and conv_out?
191
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
192
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
193
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
194
  LORA_PREFIX_UNET = 'lora_unet'
195
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
196
 
197
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
198
  super().__init__()
199
  self.multiplier = multiplier
200
+
201
  self.lora_dim = lora_dim
202
  self.alpha = alpha
203
+ self.conv_lora_dim = conv_lora_dim
204
+ self.conv_alpha = conv_alpha
205
+
206
+ if modules_dim is not None:
207
+ print(f"create LoRA network from weights")
208
+ else:
209
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
210
+
211
+ self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
212
+ if self.apply_to_conv2d_3x3:
213
+ if self.conv_alpha is None:
214
+ self.conv_alpha = self.alpha
215
+ print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
216
 
217
  # create module instances
218
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
219
  loras = []
220
  for name, module in root_module.named_modules():
221
  if module.__class__.__name__ in target_replace_modules:
222
+ # TODO get block index here
223
  for child_name, child_module in module.named_modules():
224
+ is_linear = child_module.__class__.__name__ == "Linear"
225
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
226
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
227
+ if is_linear or is_conv2d:
228
  lora_name = prefix + '.' + name + '.' + child_name
229
  lora_name = lora_name.replace('.', '_')
230
+
231
+ if modules_dim is not None:
232
+ if lora_name not in modules_dim:
233
+ continue # no LoRA module in this weights file
234
+ dim = modules_dim[lora_name]
235
+ alpha = modules_alpha[lora_name]
236
+ else:
237
+ if is_linear or is_conv2d_1x1:
238
+ dim = self.lora_dim
239
+ alpha = self.alpha
240
+ elif self.apply_to_conv2d_3x3:
241
+ dim = self.conv_lora_dim
242
+ alpha = self.conv_alpha
243
+ else:
244
+ continue
245
+
246
+ lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
247
  loras.append(lora)
248
  return loras
249
 
 
251
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
252
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
253
 
254
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
255
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
256
+ if modules_dim is not None or self.conv_lora_dim is not None:
257
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
258
+
259
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
260
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
261
 
262
  self.weights_sd = None
 
267
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
268
  names.add(lora.lora_name)
269
 
270
+ def set_multiplier(self, multiplier):
271
+ self.multiplier = multiplier
272
+ for lora in self.text_encoder_loras + self.unet_loras:
273
+ lora.multiplier = self.multiplier
274
+
275
  def load_weights(self, file):
276
  if os.path.splitext(file)[1] == '.safetensors':
277
  from safetensors.torch import load_file, safe_open
 
381
  save_file(state_dict, file, metadata)
382
  else:
383
  torch.save(state_dict, file)
384
+
385
+ @ staticmethod
386
+ def set_regions(networks, image):
387
+ image = image.astype(np.float32) / 255.0
388
+ for i, network in enumerate(networks[:3]):
389
+ # NOTE: consider averaging overwrapping area
390
+ region = image[:, :, i]
391
+ if region.max() == 0:
392
+ continue
393
+ region = torch.tensor(region)
394
+ network.set_region(region)
395
+
396
+ def set_region(self, region):
397
+ for lora in self.unet_loras:
398
+ lora.set_region(region)
networks/merge_lora.py CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
- if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
80
 
81
  # W <- W + U * D
82
  weight = module.weight
 
83
  if len(weight.size()) == 2:
84
  # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
- else:
87
- # conv2d
88
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
89
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
 
 
90
 
91
  module.weight = torch.nn.Parameter(weight)
92
 
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
123
  alphas[lora_module_name] = alpha
124
  if lora_module_name not in base_alphas:
125
  base_alphas[lora_module_name] = alpha
126
-
127
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
128
 
129
  # merge
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
145
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
146
  else:
147
  merged_sd[key] = lora_sd[key] * scale
148
-
149
  # set alpha to sd
150
  for lora_module_name, alpha in base_alphas.items():
151
  key = lora_module_name + ".alpha"
 
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
 
80
 
81
  # W <- W + U * D
82
  weight = module.weight
83
+ # print(module_name, down_weight.size(), up_weight.size())
84
  if len(weight.size()) == 2:
85
  # linear
86
  weight = weight + ratio * (up_weight @ down_weight) * scale
87
+ elif down_weight.size()[2:4] == (1, 1):
88
+ # conv2d 1x1
89
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
90
  ).unsqueeze(2).unsqueeze(3) * scale
91
+ else:
92
+ # conv2d 3x3
93
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
94
+ # print(conved.size(), weight.size(), module.stride, module.padding)
95
+ weight = weight + ratio * conved * scale
96
 
97
  module.weight = torch.nn.Parameter(weight)
98
 
 
129
  alphas[lora_module_name] = alpha
130
  if lora_module_name not in base_alphas:
131
  base_alphas[lora_module_name] = alpha
132
+
133
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
134
 
135
  # merge
 
151
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
152
  else:
153
  merged_sd[key] = lora_sd[key] * scale
154
+
155
  # set alpha to sd
156
  for lora_module_name, alpha in base_alphas.items():
157
  key = lora_module_name + ".alpha"
networks/resize_lora.py CHANGED
@@ -1,14 +1,15 @@
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
- # Thanks to cloneofsimo and kohya
4
 
5
  import argparse
6
- import os
7
  import torch
8
  from safetensors.torch import load_file, save_file, safe_open
9
  from tqdm import tqdm
10
  from library import train_util, model_util
 
11
 
 
12
 
13
  def load_state_dict(file_name, dtype):
14
  if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
38
  torch.save(model, file_name)
39
 
40
 
41
- def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  network_alpha = None
43
  network_dim = None
44
  verbose_str = "\n"
45
-
46
- CLAMP_QUANTILE = 0.99
47
 
48
  # Extract loaded lora dim and alpha
49
  for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
57
  network_alpha = network_dim
58
 
59
  scale = network_alpha/network_dim
60
- new_alpha = float(scale*new_rank) # calculate new alpha from scale
61
 
62
- print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
 
63
 
64
  lora_down_weight = None
65
  lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
68
  block_down_name = None
69
  block_up_name = None
70
 
71
- print("resizing lora...")
72
  with torch.no_grad():
73
  for key, value in tqdm(lora_sd.items()):
74
  if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
85
  conv2d = (len(lora_down_weight.size()) == 4)
86
 
87
  if conv2d:
88
- lora_down_weight = lora_down_weight.squeeze()
89
- lora_up_weight = lora_up_weight.squeeze()
90
-
91
- if device:
92
- org_device = lora_up_weight.device
93
- lora_up_weight = lora_up_weight.to(args.device)
94
- lora_down_weight = lora_down_weight.to(args.device)
95
-
96
- full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
97
-
98
- U, S, Vh = torch.linalg.svd(full_weight_matrix)
99
 
100
  if verbose:
101
- s_sum = torch.sum(torch.abs(S))
102
- s_rank = torch.sum(torch.abs(S[:new_rank]))
103
- verbose_str+=f"{block_down_name:76} | "
104
- verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
 
105
 
106
- U = U[:, :new_rank]
107
- S = S[:new_rank]
108
- U = U @ torch.diag(S)
109
 
110
- Vh = Vh[:new_rank, :]
 
 
 
111
 
112
- dist = torch.cat([U.flatten(), Vh.flatten()])
113
- hi_val = torch.quantile(dist, CLAMP_QUANTILE)
114
- low_val = -hi_val
115
-
116
- U = U.clamp(low_val, hi_val)
117
- Vh = Vh.clamp(low_val, hi_val)
118
-
119
- if conv2d:
120
- U = U.unsqueeze(2).unsqueeze(3)
121
- Vh = Vh.unsqueeze(2).unsqueeze(3)
122
-
123
- if device:
124
- U = U.to(org_device)
125
- Vh = Vh.to(org_device)
126
-
127
- o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
128
- o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
129
- o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
130
 
131
  block_down_name = None
132
  block_up_name = None
133
  lora_down_weight = None
134
  lora_up_weight = None
135
  weights_loaded = False
 
136
 
137
  if verbose:
138
  print(verbose_str)
 
 
139
  print("resizing complete")
140
  return o_lora_sd, network_dim, new_alpha
141
 
@@ -151,6 +274,9 @@ def resize(args):
151
  return torch.bfloat16
152
  return None
153
 
 
 
 
154
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
155
  save_dtype = str_to_dtype(args.save_precision)
156
  if save_dtype is None:
@@ -159,17 +285,23 @@ def resize(args):
159
  print("loading Model...")
160
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
161
 
162
- print("resizing rank...")
163
- state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
164
 
165
  # update metadata
166
  if metadata is None:
167
  metadata = {}
168
 
169
  comment = metadata.get("ss_training_comment", "")
170
- metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
171
- metadata["ss_network_dim"] = str(args.new_rank)
172
- metadata["ss_network_alpha"] = str(new_alpha)
 
 
 
 
 
 
173
 
174
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
175
  metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ if __name__ == '__main__':
193
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
194
  parser.add_argument("--verbose", action="store_true",
195
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
 
 
 
 
 
196
 
197
  args = parser.parse_args()
198
  resize(args)
 
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo
4
 
5
  import argparse
 
6
  import torch
7
  from safetensors.torch import load_file, save_file, safe_open
8
  from tqdm import tqdm
9
  from library import train_util, model_util
10
+ import numpy as np
11
 
12
+ MIN_SV = 1e-6
13
 
14
  def load_state_dict(file_name, dtype):
15
  if model_util.is_safetensors(file_name):
 
39
  torch.save(model, file_name)
40
 
41
 
42
+ def index_sv_cumulative(S, target):
43
+ original_sum = float(torch.sum(S))
44
+ cumulative_sums = torch.cumsum(S, dim=0)/original_sum
45
+ index = int(torch.searchsorted(cumulative_sums, target)) + 1
46
+ if index >= len(S):
47
+ index = len(S) - 1
48
+
49
+ return index
50
+
51
+
52
+ def index_sv_fro(S, target):
53
+ S_squared = S.pow(2)
54
+ s_fro_sq = float(torch.sum(S_squared))
55
+ sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
56
+ index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
57
+ if index >= len(S):
58
+ index = len(S) - 1
59
+
60
+ return index
61
+
62
+
63
+ # Modified from Kohaku-blueleaf's extract/merge functions
64
+ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
65
+ out_size, in_size, kernel_size, _ = weight.size()
66
+ U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
67
+
68
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
69
+ lora_rank = param_dict["new_rank"]
70
+
71
+ U = U[:, :lora_rank]
72
+ S = S[:lora_rank]
73
+ U = U @ torch.diag(S)
74
+ Vh = Vh[:lora_rank, :]
75
+
76
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
77
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
78
+ del U, S, Vh, weight
79
+ return param_dict
80
+
81
+
82
+ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
83
+ out_size, in_size = weight.size()
84
+
85
+ U, S, Vh = torch.linalg.svd(weight.to(device))
86
+
87
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
88
+ lora_rank = param_dict["new_rank"]
89
+
90
+ U = U[:, :lora_rank]
91
+ S = S[:lora_rank]
92
+ U = U @ torch.diag(S)
93
+ Vh = Vh[:lora_rank, :]
94
+
95
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
96
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
97
+ del U, S, Vh, weight
98
+ return param_dict
99
+
100
+
101
+ def merge_conv(lora_down, lora_up, device):
102
+ in_rank, in_size, kernel_size, k_ = lora_down.shape
103
+ out_size, out_rank, _, _ = lora_up.shape
104
+ assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
105
+
106
+ lora_down = lora_down.to(device)
107
+ lora_up = lora_up.to(device)
108
+
109
+ merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
110
+ weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
111
+ del lora_up, lora_down
112
+ return weight
113
+
114
+
115
+ def merge_linear(lora_down, lora_up, device):
116
+ in_rank, in_size = lora_down.shape
117
+ out_size, out_rank = lora_up.shape
118
+ assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
119
+
120
+ lora_down = lora_down.to(device)
121
+ lora_up = lora_up.to(device)
122
+
123
+ weight = lora_up @ lora_down
124
+ del lora_up, lora_down
125
+ return weight
126
+
127
+
128
+ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
129
+ param_dict = {}
130
+
131
+ if dynamic_method=="sv_ratio":
132
+ # Calculate new dim and alpha based off ratio
133
+ max_sv = S[0]
134
+ min_sv = max_sv/dynamic_param
135
+ new_rank = max(torch.sum(S > min_sv).item(),1)
136
+ new_alpha = float(scale*new_rank)
137
+
138
+ elif dynamic_method=="sv_cumulative":
139
+ # Calculate new dim and alpha based off cumulative sum
140
+ new_rank = index_sv_cumulative(S, dynamic_param)
141
+ new_rank = max(new_rank, 1)
142
+ new_alpha = float(scale*new_rank)
143
+
144
+ elif dynamic_method=="sv_fro":
145
+ # Calculate new dim and alpha based off sqrt sum of squares
146
+ new_rank = index_sv_fro(S, dynamic_param)
147
+ new_rank = min(max(new_rank, 1), len(S)-1)
148
+ new_alpha = float(scale*new_rank)
149
+ else:
150
+ new_rank = rank
151
+ new_alpha = float(scale*new_rank)
152
+
153
+
154
+ if S[0] <= MIN_SV: # Zero matrix, set dim to 1
155
+ new_rank = 1
156
+ new_alpha = float(scale*new_rank)
157
+ elif new_rank > rank: # cap max rank at rank
158
+ new_rank = rank
159
+ new_alpha = float(scale*new_rank)
160
+
161
+
162
+ # Calculate resize info
163
+ s_sum = torch.sum(torch.abs(S))
164
+ s_rank = torch.sum(torch.abs(S[:new_rank]))
165
+
166
+ S_squared = S.pow(2)
167
+ s_fro = torch.sqrt(torch.sum(S_squared))
168
+ s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
169
+ fro_percent = float(s_red_fro/s_fro)
170
+
171
+ param_dict["new_rank"] = new_rank
172
+ param_dict["new_alpha"] = new_alpha
173
+ param_dict["sum_retained"] = (s_rank)/s_sum
174
+ param_dict["fro_retained"] = fro_percent
175
+ param_dict["max_ratio"] = S[0]/S[new_rank]
176
+
177
+ return param_dict
178
+
179
+
180
+ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
181
  network_alpha = None
182
  network_dim = None
183
  verbose_str = "\n"
184
+ fro_list = []
 
185
 
186
  # Extract loaded lora dim and alpha
187
  for key, value in lora_sd.items():
 
195
  network_alpha = network_dim
196
 
197
  scale = network_alpha/network_dim
 
198
 
199
+ if dynamic_method:
200
+ print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
201
 
202
  lora_down_weight = None
203
  lora_up_weight = None
 
206
  block_down_name = None
207
  block_up_name = None
208
 
 
209
  with torch.no_grad():
210
  for key, value in tqdm(lora_sd.items()):
211
  if 'lora_down' in key:
 
222
  conv2d = (len(lora_down_weight.size()) == 4)
223
 
224
  if conv2d:
225
+ full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
226
+ param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
227
+ else:
228
+ full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
229
+ param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
 
 
 
 
 
230
 
231
  if verbose:
232
+ max_ratio = param_dict['max_ratio']
233
+ sum_retained = param_dict['sum_retained']
234
+ fro_retained = param_dict['fro_retained']
235
+ if not np.isnan(fro_retained):
236
+ fro_list.append(float(fro_retained))
237
 
238
+ verbose_str+=f"{block_down_name:75} | "
239
+ verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
240
 
241
+ if verbose and dynamic_method:
242
+ verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
243
+ else:
244
+ verbose_str+=f"\n"
245
 
246
+ new_alpha = param_dict['new_alpha']
247
+ o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
248
+ o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
249
+ o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  block_down_name = None
252
  block_up_name = None
253
  lora_down_weight = None
254
  lora_up_weight = None
255
  weights_loaded = False
256
+ del param_dict
257
 
258
  if verbose:
259
  print(verbose_str)
260
+
261
+ print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
262
  print("resizing complete")
263
  return o_lora_sd, network_dim, new_alpha
264
 
 
274
  return torch.bfloat16
275
  return None
276
 
277
+ if args.dynamic_method and not args.dynamic_param:
278
+ raise Exception("If using dynamic_method, then dynamic_param is required")
279
+
280
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
281
  save_dtype = str_to_dtype(args.save_precision)
282
  if save_dtype is None:
 
285
  print("loading Model...")
286
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
287
 
288
+ print("Resizing Lora...")
289
+ state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
290
 
291
  # update metadata
292
  if metadata is None:
293
  metadata = {}
294
 
295
  comment = metadata.get("ss_training_comment", "")
296
+
297
+ if not args.dynamic_method:
298
+ metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
299
+ metadata["ss_network_dim"] = str(args.new_rank)
300
+ metadata["ss_network_alpha"] = str(new_alpha)
301
+ else:
302
+ metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
303
+ metadata["ss_network_dim"] = 'Dynamic'
304
+ metadata["ss_network_alpha"] = 'Dynamic'
305
 
306
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
307
  metadata["sshs_model_hash"] = model_hash
 
325
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
326
  parser.add_argument("--verbose", action="store_true",
327
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
328
+ parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
329
+ help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
330
+ parser.add_argument("--dynamic_param", type=float, default=None,
331
+ help="Specify target for dynamic reduction")
332
+
333
 
334
  args = parser.parse_args()
335
  resize(args)
networks/svd_merge_lora.py CHANGED
@@ -23,19 +23,20 @@ def load_state_dict(file_name, dtype):
23
  return sd
24
 
25
 
26
- def save_to_file(file_name, model, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
- save_file(model, file_name)
34
  else:
35
- torch.save(model, file_name)
36
 
37
 
38
- def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
 
39
  merged_sd = {}
40
  for model, ratio in zip(models, ratios):
41
  print(f"loading: {model}")
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
58
  in_dim = down_weight.size()[1]
59
  out_dim = up_weight.size()[0]
60
  conv2d = len(down_weight.size()) == 4
61
- print(lora_module_name, network_dim, alpha, in_dim, out_dim)
 
62
 
63
  # make original weight if not exist
64
  if lora_module_name not in merged_sd:
65
- weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
66
  if device:
67
  weight = weight.to(device)
68
  else:
@@ -75,11 +77,18 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
75
 
76
  # W <- W + U * D
77
  scale = (alpha / network_dim)
 
 
 
 
78
  if not conv2d: # linear
79
  weight = weight + ratio * (up_weight @ down_weight) * scale
80
- else:
81
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
82
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
83
 
84
  merged_sd[lora_module_name] = weight
85
 
@@ -89,16 +98,26 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
89
  with torch.no_grad():
90
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
91
  conv2d = (len(mat.size()) == 4)
 
 
 
 
92
  if conv2d:
93
- mat = mat.squeeze()
 
 
 
 
 
 
94
 
95
  U, S, Vh = torch.linalg.svd(mat)
96
 
97
- U = U[:, :new_rank]
98
- S = S[:new_rank]
99
  U = U @ torch.diag(S)
100
 
101
- Vh = Vh[:new_rank, :]
102
 
103
  dist = torch.cat([U.flatten(), Vh.flatten()])
104
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
@@ -107,16 +126,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
107
  U = U.clamp(low_val, hi_val)
108
  Vh = Vh.clamp(low_val, hi_val)
109
 
 
 
 
 
110
  up_weight = U
111
  down_weight = Vh
112
 
113
- if conv2d:
114
- up_weight = up_weight.unsqueeze(2).unsqueeze(3)
115
- down_weight = down_weight.unsqueeze(2).unsqueeze(3)
116
-
117
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
118
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
119
- merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
120
 
121
  return merged_lora_sd
122
 
@@ -138,10 +157,11 @@ def merge(args):
138
  if save_dtype is None:
139
  save_dtype = merge_dtype
140
 
141
- state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
 
142
 
143
  print(f"saving model to: {args.save_to}")
144
- save_to_file(args.save_to, state_dict, state_dict, save_dtype)
145
 
146
 
147
  if __name__ == '__main__':
@@ -158,6 +178,8 @@ if __name__ == '__main__':
158
  help="ratios for each model / それぞれのLoRAモデルの比率")
159
  parser.add_argument("--new_rank", type=int, default=4,
160
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
 
23
  return sd
24
 
25
 
26
+ def save_to_file(file_name, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
+ save_file(state_dict, file_name)
34
  else:
35
+ torch.save(state_dict, file_name)
36
 
37
 
38
+ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
39
+ print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
40
  merged_sd = {}
41
  for model, ratio in zip(models, ratios):
42
  print(f"loading: {model}")
 
59
  in_dim = down_weight.size()[1]
60
  out_dim = up_weight.size()[0]
61
  conv2d = len(down_weight.size()) == 4
62
+ kernel_size = None if not conv2d else down_weight.size()[2:4]
63
+ # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
64
 
65
  # make original weight if not exist
66
  if lora_module_name not in merged_sd:
67
+ weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
68
  if device:
69
  weight = weight.to(device)
70
  else:
 
77
 
78
  # W <- W + U * D
79
  scale = (alpha / network_dim)
80
+
81
+ if device: # and isinstance(scale, torch.Tensor):
82
+ scale = scale.to(device)
83
+
84
  if not conv2d: # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ elif kernel_size == (1, 1):
87
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
88
  ).unsqueeze(2).unsqueeze(3) * scale
89
+ else:
90
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
91
+ weight = weight + ratio * conved * scale
92
 
93
  merged_sd[lora_module_name] = weight
94
 
 
98
  with torch.no_grad():
99
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
100
  conv2d = (len(mat.size()) == 4)
101
+ kernel_size = None if not conv2d else mat.size()[2:4]
102
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
103
+ out_dim, in_dim = mat.size()[0:2]
104
+
105
  if conv2d:
106
+ if conv2d_3x3:
107
+ mat = mat.flatten(start_dim=1)
108
+ else:
109
+ mat = mat.squeeze()
110
+
111
+ module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
112
+ module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
113
 
114
  U, S, Vh = torch.linalg.svd(mat)
115
 
116
+ U = U[:, :module_new_rank]
117
+ S = S[:module_new_rank]
118
  U = U @ torch.diag(S)
119
 
120
+ Vh = Vh[:module_new_rank, :]
121
 
122
  dist = torch.cat([U.flatten(), Vh.flatten()])
123
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
 
126
  U = U.clamp(low_val, hi_val)
127
  Vh = Vh.clamp(low_val, hi_val)
128
 
129
+ if conv2d:
130
+ U = U.reshape(out_dim, module_new_rank, 1, 1)
131
+ Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
132
+
133
  up_weight = U
134
  down_weight = Vh
135
 
 
 
 
 
136
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
137
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
138
+ merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
139
 
140
  return merged_lora_sd
141
 
 
157
  if save_dtype is None:
158
  save_dtype = merge_dtype
159
 
160
+ new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
161
+ state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
162
 
163
  print(f"saving model to: {args.save_to}")
164
+ save_to_file(args.save_to, state_dict, save_dtype)
165
 
166
 
167
  if __name__ == '__main__':
 
178
  help="ratios for each model / それぞれのLoRAモデルの比率")
179
  parser.add_argument("--new_rank", type=int, default=4,
180
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
181
+ parser.add_argument("--new_conv_rank", type=int, default=None,
182
+ help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
183
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
184
 
185
  args = parser.parse_args()
requirements.txt CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
 
 
15
  # for BLIP captioning
16
  requests==2.28.2
17
  timm==0.6.12
 
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
15
+ toml==0.10.2
16
+ voluptuous==0.13.1
17
  # for BLIP captioning
18
  requests==2.28.2
19
  timm==0.6.12
tools/canny.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+
4
+
5
+ def canny(args):
6
+ img = cv2.imread(args.input)
7
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
8
+
9
+ canny_img = cv2.Canny(img, args.thres1, args.thres2)
10
+ # canny_img = 255 - canny_img
11
+
12
+ cv2.imwrite(args.output, canny_img)
13
+ print("done!")
14
+
15
+
16
+ if __name__ == '__main__':
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--input", type=str, default=None, help="input path")
19
+ parser.add_argument("--output", type=str, default=None, help="output path")
20
+ parser.add_argument("--thres1", type=int, default=32, help="thres1")
21
+ parser.add_argument("--thres2", type=int, default=224, help="thres2")
22
+
23
+ args = parser.parse_args()
24
+ canny(args)
tools/original_control_net.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, NamedTuple, Any
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ from safetensors.torch import load_file
6
+
7
+ from diffusers import UNet2DConditionModel
8
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
9
+
10
+ import library.model_util as model_util
11
+
12
+
13
+ class ControlNetInfo(NamedTuple):
14
+ unet: Any
15
+ net: Any
16
+ prep: Any
17
+ weight: float
18
+ ratio: float
19
+
20
+
21
+ class ControlNet(torch.nn.Module):
22
+ def __init__(self) -> None:
23
+ super().__init__()
24
+
25
+ # make control model
26
+ self.control_model = torch.nn.Module()
27
+
28
+ dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
29
+ zero_convs = torch.nn.ModuleList()
30
+ for i, dim in enumerate(dims):
31
+ sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
32
+ zero_convs.append(sub_list)
33
+ self.control_model.add_module("zero_convs", zero_convs)
34
+
35
+ middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
36
+ self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
37
+
38
+ dims = [16, 16, 32, 32, 96, 96, 256, 320]
39
+ strides = [1, 1, 2, 1, 2, 1, 2, 1]
40
+ prev_dim = 3
41
+ input_hint_block = torch.nn.Sequential()
42
+ for i, (dim, stride) in enumerate(zip(dims, strides)):
43
+ input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
44
+ if i < len(dims) - 1:
45
+ input_hint_block.append(torch.nn.SiLU())
46
+ prev_dim = dim
47
+ self.control_model.add_module("input_hint_block", input_hint_block)
48
+
49
+
50
+ def load_control_net(v2, unet, model):
51
+ device = unet.device
52
+
53
+ # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
54
+ # state dictを読み込む
55
+ print(f"ControlNet: loading control SD model : {model}")
56
+
57
+ if model_util.is_safetensors(model):
58
+ ctrl_sd_sd = load_file(model)
59
+ else:
60
+ ctrl_sd_sd = torch.load(model, map_location='cpu')
61
+ ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
62
+
63
+ # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
64
+ is_difference = "difference" in ctrl_sd_sd
65
+ print("ControlNet: loading difference")
66
+
67
+ # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
68
+ # またTransfer Controlの元weightとなる
69
+ ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
70
+
71
+ # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
72
+ for key in list(ctrl_unet_sd_sd.keys()):
73
+ ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
74
+
75
+ zero_conv_sd = {}
76
+ for key in list(ctrl_sd_sd.keys()):
77
+ if key.startswith("control_"):
78
+ unet_key = "model.diffusion_" + key[len("control_"):]
79
+ if unet_key not in ctrl_unet_sd_sd: # zero conv
80
+ zero_conv_sd[key] = ctrl_sd_sd[key]
81
+ continue
82
+ if is_difference: # Transfer Control
83
+ ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
84
+ else:
85
+ ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
86
+
87
+ unet_config = model_util.create_unet_diffusers_config(v2)
88
+ ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
89
+
90
+ # ControlNetのU-Netを作成する
91
+ ctrl_unet = UNet2DConditionModel(**unet_config)
92
+ info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
93
+ print("ControlNet: loading Control U-Net:", info)
94
+
95
+ # U-Net以外のControlNetを作成する
96
+ # TODO support middle only
97
+ ctrl_net = ControlNet()
98
+ info = ctrl_net.load_state_dict(zero_conv_sd)
99
+ print("ControlNet: loading ControlNet:", info)
100
+
101
+ ctrl_unet.to(unet.device, dtype=unet.dtype)
102
+ ctrl_net.to(unet.device, dtype=unet.dtype)
103
+ return ctrl_unet, ctrl_net
104
+
105
+
106
+ def load_preprocess(prep_type: str):
107
+ if prep_type is None or prep_type.lower() == "none":
108
+ return None
109
+
110
+ if prep_type.startswith("canny"):
111
+ args = prep_type.split("_")
112
+ th1 = int(args[1]) if len(args) >= 2 else 63
113
+ th2 = int(args[2]) if len(args) >= 3 else 191
114
+
115
+ def canny(img):
116
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
117
+ return cv2.Canny(img, th1, th2)
118
+ return canny
119
+
120
+ print("Unsupported prep type:", prep_type)
121
+ return None
122
+
123
+
124
+ def preprocess_ctrl_net_hint_image(image):
125
+ image = np.array(image).astype(np.float32) / 255.0
126
+ image = image[:, :, ::-1].copy() # rgb to bgr
127
+ image = image[None].transpose(0, 3, 1, 2) # nchw
128
+ image = torch.from_numpy(image)
129
+ return image # 0 to 1
130
+
131
+
132
+ def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
133
+ guided_hints = []
134
+ for i, cnet_info in enumerate(control_nets):
135
+ # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
136
+ b_hints = []
137
+ if len(hints) == 1: # すべて同じ画像をhintとして使う
138
+ hint = hints[0]
139
+ if cnet_info.prep is not None:
140
+ hint = cnet_info.prep(hint)
141
+ hint = preprocess_ctrl_net_hint_image(hint)
142
+ b_hints = [hint for _ in range(b_size)]
143
+ else:
144
+ for bi in range(b_size):
145
+ hint = hints[(bi * len(control_nets) + i) % len(hints)]
146
+ if cnet_info.prep is not None:
147
+ hint = cnet_info.prep(hint)
148
+ hint = preprocess_ctrl_net_hint_image(hint)
149
+ b_hints.append(hint)
150
+ b_hints = torch.cat(b_hints, dim=0)
151
+ b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
152
+
153
+ guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
154
+ guided_hints.append(guided_hint)
155
+ return guided_hints
156
+
157
+
158
+ def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
159
+ # ControlNet
160
+ # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
161
+ cnet_cnt = len(control_nets)
162
+ cnet_idx = step % cnet_cnt
163
+ cnet_info = control_nets[cnet_idx]
164
+
165
+ # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
166
+ if cnet_info.ratio < current_ratio:
167
+ return original_unet(sample, timestep, encoder_hidden_states)
168
+
169
+ guided_hint = guided_hints[cnet_idx]
170
+ guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
171
+ outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
172
+ outs = [o * cnet_info.weight for o in outs]
173
+
174
+ # U-Net
175
+ return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
176
+
177
+
178
+ """
179
+ # これはmergeのバージョン
180
+ # ControlNet
181
+ cnet_outs_list = []
182
+ for i, cnet_info in enumerate(control_nets):
183
+ # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
184
+ if cnet_info.ratio < current_ratio:
185
+ continue
186
+ guided_hint = guided_hints[i]
187
+ outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
188
+ for i in range(len(outs)):
189
+ outs[i] *= cnet_info.weight
190
+
191
+ cnet_outs_list.append(outs)
192
+
193
+ count = len(cnet_outs_list)
194
+ if count == 0:
195
+ return original_unet(sample, timestep, encoder_hidden_states)
196
+
197
+ # sum of controlnets
198
+ for i in range(1, count):
199
+ cnet_outs_list[0] += cnet_outs_list[i]
200
+
201
+ # U-Net
202
+ return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
203
+ """
204
+
205
+
206
+ def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
207
+ # copy from UNet2DConditionModel
208
+ default_overall_up_factor = 2**unet.num_upsamplers
209
+
210
+ forward_upsample_size = False
211
+ upsample_size = None
212
+
213
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
214
+ print("Forward upsample size to force interpolation output size.")
215
+ forward_upsample_size = True
216
+
217
+ # 0. center input if necessary
218
+ if unet.config.center_input_sample:
219
+ sample = 2 * sample - 1.0
220
+
221
+ # 1. time
222
+ timesteps = timestep
223
+ if not torch.is_tensor(timesteps):
224
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
225
+ # This would be a good case for the `match` statement (Python 3.10+)
226
+ is_mps = sample.device.type == "mps"
227
+ if isinstance(timestep, float):
228
+ dtype = torch.float32 if is_mps else torch.float64
229
+ else:
230
+ dtype = torch.int32 if is_mps else torch.int64
231
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
232
+ elif len(timesteps.shape) == 0:
233
+ timesteps = timesteps[None].to(sample.device)
234
+
235
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
236
+ timesteps = timesteps.expand(sample.shape[0])
237
+
238
+ t_emb = unet.time_proj(timesteps)
239
+
240
+ # timesteps does not contain any weights and will always return f32 tensors
241
+ # but time_embedding might actually be running in fp16. so we need to cast here.
242
+ # there might be better ways to encapsulate this.
243
+ t_emb = t_emb.to(dtype=unet.dtype)
244
+ emb = unet.time_embedding(t_emb)
245
+
246
+ outs = [] # output of ControlNet
247
+ zc_idx = 0
248
+
249
+ # 2. pre-process
250
+ sample = unet.conv_in(sample)
251
+ if is_control_net:
252
+ sample += guided_hint
253
+ outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
254
+ zc_idx += 1
255
+
256
+ # 3. down
257
+ down_block_res_samples = (sample,)
258
+ for downsample_block in unet.down_blocks:
259
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
260
+ sample, res_samples = downsample_block(
261
+ hidden_states=sample,
262
+ temb=emb,
263
+ encoder_hidden_states=encoder_hidden_states,
264
+ )
265
+ else:
266
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
267
+ if is_control_net:
268
+ for rs in res_samples:
269
+ outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
270
+ zc_idx += 1
271
+
272
+ down_block_res_samples += res_samples
273
+
274
+ # 4. mid
275
+ sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
276
+ if is_control_net:
277
+ outs.append(control_net.control_model.middle_block_out[0](sample))
278
+ return outs
279
+
280
+ if not is_control_net:
281
+ sample += ctrl_outs.pop()
282
+
283
+ # 5. up
284
+ for i, upsample_block in enumerate(unet.up_blocks):
285
+ is_final_block = i == len(unet.up_blocks) - 1
286
+
287
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
288
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
289
+
290
+ if not is_control_net and len(ctrl_outs) > 0:
291
+ res_samples = list(res_samples)
292
+ apply_ctrl_outs = ctrl_outs[-len(res_samples):]
293
+ ctrl_outs = ctrl_outs[:-len(res_samples)]
294
+ for j in range(len(res_samples)):
295
+ res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
296
+ res_samples = tuple(res_samples)
297
+
298
+ # if we have not reached the final block and need to forward the
299
+ # upsample size, we do it here
300
+ if not is_final_block and forward_upsample_size:
301
+ upsample_size = down_block_res_samples[-1].shape[2:]
302
+
303
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
304
+ sample = upsample_block(
305
+ hidden_states=sample,
306
+ temb=emb,
307
+ res_hidden_states_tuple=res_samples,
308
+ encoder_hidden_states=encoder_hidden_states,
309
+ upsample_size=upsample_size,
310
+ )
311
+ else:
312
+ sample = upsample_block(
313
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
314
+ )
315
+ # 6. post-process
316
+ sample = unet.conv_norm_out(sample)
317
+ sample = unet.conv_act(sample)
318
+ sample = unet.conv_out(sample)
319
+
320
+ return UNet2DConditionOutput(sample=sample)
train_README-ja.md ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __ドキュメント更新中のため記述に誤りがあるかもしれません。__
2
+
3
+ # 学習について、共通編
4
+
5
+ 当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
6
+
7
+ # 概要
8
+
9
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
10
+
11
+
12
+ 以下について説明します。
13
+
14
+ 1. 学習データの準備について(設定ファイルを用いる新形式)
15
+ 1. 学習で使われる用語のごく簡単な解説
16
+ 1. 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
17
+ 1. 学習途中のサンプル画像生成
18
+ 1. 各スクリプトで共通の、よく使われるオプション
19
+ 1. fine tuning 方式のメタデータ準備:キャプションニングなど
20
+
21
+ 1.だけ実行すればとりあえず学習は可能です(学習については各スクリプトのドキュメントを参照)。2.以降は必要に応じて参照してください。
22
+
23
+
24
+ # 学習データの準備について
25
+
26
+ 任意のフォルダ(複数でも可)に学習データの画像ファイルを用意しておきます。`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` をサポートします。リサイズなどの前処理は基本的に必要ありません。
27
+
28
+ ただし学習解像度(後述)よりも極端に小さい画像は使わないか、あらかじめ超解像AIなどで拡大しておくことをお勧めします。また極端に大きな画像(3000x3000ピクセル程度?)よりも大きな画像はエラーになる場合があるようですので事前に縮小してください。
29
+
30
+ 学習時には、モデルに学ばせる画像データを整理し、スクリプトに対して指定する必要があります。学習データの数、学習対象、キャプション(画像の説明)が用意できるか否かなどにより、いくつかの方法で学習データを指定できます。以下の方式があります(それぞれの名前は一般的なものではなく、当リポジトリ独自の定義です)。正則化画像については後述します。
31
+
32
+ 1. DreamBooth、class+identifier方式(正則化画像使用可)
33
+
34
+ 特定の単語 (identifier) に学習対象を紐づけるように学習します。キャプションを用意する必要はありません。たとえば特定のキャラを学ばせる場合に使うとキャプションを用意する必要がない分、手軽ですが、髪型や服装、背景など学習データの全要素が identifier に紐づけられて学習されるため、生成時のプロンプトで服が変えられない、といった事態も起こりえます。
35
+
36
+ 1. DreamBooth、キャプション方式(正則化画像使用可)
37
+
38
+ 画像ごとにキャプションが記録されたテキストファイルを用意して学習します。たとえば特定のキャラを学ばせると、画像の詳細をキャプションに記述することで(白い服を着たキャラA、赤い服を着たキャラA、など)キャラとそれ以外の要素が分離され、より厳密にモデルがキャラだけを学ぶことが期待できます。
39
+
40
+ 1. fine tuning方式(正則化画像使用不可)
41
+
42
+ あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。(fine tuning方式という名前ですが fine tuning 以外でも使えます。)
43
+
44
+ 学習したいものと使用できる指定方法の組み合わせは以下の通りです。
45
+
46
+ | 学習対象または方法 | スクリプト | DB / class+identifier | DB / キャプション | fine tuning |
47
+ | ----- | ----- | ----- | ----- | ----- |
48
+ | モデルをfine tuning | `fine_tune.py`| x | x | o |
49
+ | モデルをDreamBooth | `train_db.py`| o | o | x |
50
+ | LoRA | `train_network.py`| o | o | o |
51
+ | Textual Invesion | `train_textual_inversion.py`| o | o | o |
52
+
53
+ ## どれを選ぶか
54
+
55
+ LoRA、Textual Inversionについては、手軽にキャプションファイルを用意せずに学習したい場合はDreamBooth class+identifier、用意できるならDreamBooth キャプション方式がよいでしょう。学習データの枚数が多く、かつ正則化画像を使用しない場合はfine tuning方式も検討してください。
56
+
57
+ DreamBoothについても同様ですが、fine tuning方式は使えません。fine tuningの場合はfine tuning方式のみです。
58
+
59
+ # 各方式の指定方法について
60
+
61
+ ここではそれぞれの指定方法で典型的なパターンについてだけ説明します。より詳細な指定方法については [データセット��定](./config_README-ja.md) をご覧ください。
62
+
63
+ # DreamBooth、class+identifier方式(正則化画像使用可)
64
+
65
+ この方式では、各画像は `class identifier` というキャプションで学習されたのと同じことになります(`shs dog` など)。
66
+
67
+ ## step 1. identifierとclassを決める
68
+
69
+ 学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。
70
+
71
+ (instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。)
72
+
73
+ 以下ごく簡単に説明します(詳しくは調べてください)。
74
+
75
+ classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。
76
+
77
+ identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokinizerで1トークンになる3文字以下でレアな単語」が良いとのことです。
78
+
79
+ identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。
80
+
81
+ 画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。
82
+
83
+ (identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。本当は Danbooru Tag に含まれないやつがより望ましいです。)
84
+
85
+ ## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
86
+
87
+ 正則化画像とは、前述のclass全体が、学習対象に引っ張られることを防ぐための画像です(language drift)。正則化画像を使わないと、たとえば `shs 1girl` で特定のキャラクタを学ばせると、単なる `1girl` というプロンプトで生成してもそのキャラに似てきます。これは `1girl` が学習時のキャプションに含まれているためです。
88
+
89
+ 学習対象の画像と正則化画像を同時に学ばせることで、class は class のままで留まり、identifier をプロンプトにつけた時だけ学習対象が生成されるようになります。
90
+
91
+ LoRAやDreamBoothで特定のキャラだけ出てくればよい場合は、正則化画像を用いなくても良いといえます。
92
+
93
+ Textual Inversionでは用いなくてよいでしょう(学ばせる token string がキャプションに含まれない場合はなにも学習されないため)。
94
+
95
+ 正則化画像としては、学習対象のモデルで、class 名だけで生成した画像を用いるのが一般的です(たとえば `1girl`)。ただし生成画像の品質が悪い場合には、プロンプトを工夫したり、ネットから別途ダウンロードした画像を用いることもできます。
96
+
97
+ (正則化画像も学習されるため、その品質はモデルに影響します。)
98
+
99
+ 一般的には数百枚程度、用意するのが望ましいようです(枚数が少ないと class 画像が一般化されずそれらの特徴を学んでしまいます)。
100
+
101
+ 生成画像を使う場合、通常、生成画像のサイズは学習解像度(より正確にはbucketの解像度、後述)にあわせてください。
102
+
103
+ ## step 2. 設定ファイルの記述
104
+
105
+ テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
106
+
107
+ (`#` で始まっている部分はコメントですので、このままコピペしてそのままでもよいですし、削除しても問題ありません。)
108
+
109
+ ```toml
110
+ [general]
111
+ enable_bucket = true # Aspect Ratio Bucketingを使うか否か
112
+
113
+ [[datasets]]
114
+ resolution = 512 # 学習解像度
115
+ batch_size = 4 # バッチサイズ
116
+
117
+ [[datasets.subsets]]
118
+ image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
119
+ class_tokens = 'hoge girl' # identifier class を指定
120
+ num_repeats = 10 # 学習用画像の繰り返し回数
121
+
122
+ # 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
123
+ [[datasets.subsets]]
124
+ is_reg = true
125
+ image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
126
+ class_tokens = 'girl' # class を指定
127
+ num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
128
+ ```
129
+
130
+ 基本的には以下の場所のみ書き換えれば学習できます。
131
+
132
+ 1. 学習解像度
133
+
134
+ 数値1つを指定すると正方形(`512`なら512x512)、鍵カッコカンマ区切りで2つ指定すると横×縦(`[512,768]`なら512x768)になります。SD1.x系では��ともとの学習解像度は512です。`[512,768]` 等の大きめの解像度を指定すると縦長、横長画像生成時の破綻を小さくできるかもしれません。SD2.x 768系では `768` です。
135
+
136
+ 1. バッチサイズ
137
+
138
+ 同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。詳しくは後述します。またfine tuning/DreamBooth/LoRA等でも変わってきますので各スクリプトの説明もご覧ください。
139
+
140
+ 1. フォルダ指定
141
+
142
+ 学習用画像、正則化画像(使用する場合のみ)のフォルダを指定します。画像データが含まれているフォルダそのものを指定します。
143
+
144
+ 1. identifier と class の指定
145
+
146
+ 前述のサンプルの通りです。
147
+
148
+ 1. 繰り返し回数
149
+
150
+ 後述します。
151
+
152
+ ### 繰り返し回数について
153
+
154
+ 繰り返し回数は、正則化画像の枚数と学習用画像の枚数を調整するために用いられます。正則化画像の枚数は学習用画像よりも多いため、学習用画像を繰り返して枚数を合わせ、1対1の比率で学習できるようにします。
155
+
156
+ 繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。
157
+
158
+ (1 epoch(データが一周すると1 epoch)のデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。)
159
+
160
+ ## step 3. 学習
161
+
162
+ それぞれのドキュメントを参考に学習を行ってください。
163
+
164
+ # DreamBooth、キャプション方式(正則化画像使用可)
165
+
166
+ この方式では各画像はキャプションで学習されます。
167
+
168
+ ## step 1. キャプションファイルを準備する
169
+
170
+ 学習用画像のフォルダに、画像と同じファイル名で、拡張子 `.caption`(設定で変えられます)のファイルを置いてください。それぞれのファイルは1行のみとしてください。エンコーディングは `UTF-8` です。
171
+
172
+ ## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
173
+
174
+ class+identifier形式と同様です。なお正則化画像にもキャプションを付けることができますが、通常は不要でしょう。
175
+
176
+ ## step 2. 設定ファイルの記述
177
+
178
+ テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
179
+
180
+ ```toml
181
+ [general]
182
+ enable_bucket = true # Aspect Ratio Bucketingを使うか否か
183
+
184
+ [[datasets]]
185
+ resolution = 512 # 学習解像度
186
+ batch_size = 4 # バッチサイズ
187
+
188
+ [[datasets.subsets]]
189
+ image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
190
+ caption_extension = '.caption' # キャプションファイルの拡張子 .txt を使う場合には書き換える
191
+ num_repeats = 10 # 学習用画像の繰り返し回数
192
+
193
+ # 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
194
+ [[datasets.subsets]]
195
+ is_reg = true
196
+ image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
197
+ class_tokens = 'girl' # class を指定
198
+ num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
199
+ ```
200
+
201
+ 基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は class+identifier 方式と同じです。
202
+
203
+ 1. 学習解像度
204
+ 1. バッチサイズ
205
+ 1. フォルダ指定
206
+ 1. キャプションファイルの拡張子
207
+
208
+ 任意の拡張子を指定できます。
209
+ 1. 繰り返し回数
210
+
211
+ ## step 3. 学習
212
+
213
+ それぞれのドキュメントを参考に学習を行ってください。
214
+
215
+ # fine tuning 方式
216
+
217
+ ## step 1. メタデータを準備する
218
+
219
+ キャプションやタグをまとめた管理用ファイルをメタデータと呼びます。json形式で拡張子は `.json`
220
+ です。作成方法は長くなりますのでこの文書の末尾に書きました。
221
+
222
+ ## step 2. 設定ファイルの記述
223
+
224
+ テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
225
+
226
+ ```toml
227
+ [general]
228
+ shuffle_caption = true
229
+ keep_tokens = 1
230
+
231
+ [[datasets]]
232
+ resolution = 512 # 学習解像度
233
+ batch_size = 4 # バッチサイズ
234
+
235
+ [[datasets.subsets]]
236
+ image_dir = 'C:\piyo' # 学習用画像を入れたフォルダを指定
237
+ metadata_file = 'C:\piyo\piyo_md.json' # メタデータファイル名
238
+ ```
239
+
240
+ 基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は DreamBooth, class+identifier 方式と同じです。
241
+
242
+ 1. 学習解像度
243
+ 1. バッチサイズ
244
+ 1. フォルダ指定
245
+ 1. メタデータファイル名
246
+
247
+ 後述の方法で作成したメタデータファイルを指定します。
248
+
249
+
250
+ ## step 3. 学習
251
+
252
+ それぞれのドキュメントを参考に学習を行ってください。
253
+
254
+ # 学習で使われる用語のごく簡単な解説
255
+
256
+ 細かいことは省略していますし私も完全には理解していないため、詳しくは各自お調べください。
257
+
258
+ ## fine tuning(ファインチューニング)
259
+
260
+ モデルを学習して微調整することを指します。使われ方によって意味が異なってきますが、狭義のfine tuningはStable Diffusionの場合、モデルを画像とキャプションで学習することです。DreamBoothは狭義のfine tuningのひとつの特殊なやり方と言えます。広義のfine tuningは、LoRAやTextual Inversion、Hypernetworksなどを含み、モデルを学習することすべてを含みます。
261
+
262
+ ## ステップ
263
+
264
+ ざっくりいうと学習データで1回計算すると1ステップです。「学習データのキャプションを今のモデルに流してみて、出てくる画像を学習データの画像と比較し、学習データに近づくようにモデルをわずかに変更する」のが1ステップです。
265
+
266
+ ## バッチサイズ
267
+
268
+ バッチサイズは1ステップで何件のデータをまとめて計算するかを指定する値です。まとめて計算するため速度は相対的に向上します。また一般的には精度も高くなるといわれています。
269
+
270
+ `バッチサイズ×ステップ数` が学習に使われるデータの件数になります。そのため、バッチサイズを増やした分だけステップ数を減らすとよいでしょう。
271
+
272
+ (ただし、たとえば「バッチサイズ1で1600ステップ」と「バッチサイズ4で400ステップ」は同じ結果にはなりません。同じ学習率の場合、一般的には後者のほうが学習不足になります。学習率を多少大きくするか(たとえば `2e-6` など)、ステップ数をたとえば500ステップにするなどして工夫してください。)
273
+
274
+ バッチサイズを大きくするとその分だけGPUメモリを消費します。メモリが足りなくなるとエラーになりますし、エラーにならないギリギリでは学習速度が低下します。タスクマネージャーや `nvidia-smi` コマンドで使用メモリ量を確認しながら調整するとよいでしょう。
275
+
276
+ なお、バッチは「一塊のデータ」位の意味です。
277
+
278
+ ## 学習率
279
+
280
+ ざっくりいうと1ステップごとにどのくらい変化させるかを表します。大きな値を指定するとそれだけ速く学習が進みますが、変化しすぎてモデルが壊れたり、最適な状態にまで至れない場合があります。小さい値を指定すると学習速度は遅くなり、また最適な状態にやはり至れない場合があります。
281
+
282
+ fine tuning、DreamBoooth、LoRAそれぞれで大きく異なり、また学習データや学習させたいモデル、バッチサイズやステップ数によっても変わってきます。一般的な値から初めて学習状態を見ながら増減してください。
283
+
284
+ デフォルトでは学習全体を通して学習率は固定です。スケジューラの指定で学習率をどう変化させるか決められますので、それらによっても結果は変わってきます。
285
+
286
+ ## エポック(epoch)
287
+
288
+ 学習データが一通り学習されると(データが一周すると)1 epochです。繰り返し回数を指定した場合は、その繰り返し後のデータが一周すると1 epochです。
289
+
290
+ 1 epochのステップ数は、基本的には `データ件数÷バッチサイズ` ですが、Aspect Ratio Bucketing を使うと微妙に増えます(異なるbucketのデータは同じバッチにできないため、ステップ数が増えます)。
291
+
292
+ ## Aspect Ratio Bucketing
293
+
294
+ Stable Diffusion のv1は512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくキャプションと画像の関係が学習されることが期待されます。
295
+
296
+ また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。
297
+
298
+ 設定で有効、向こうが切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
299
+
300
+ 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位(デフォルト、変��可)で縦横に調整、作成されます。
301
+
302
+ 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
303
+
304
+ # 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
305
+
306
+ `.toml` ファイルを指定せずコマンドラインオプションで指定する方法です。DreamBooth class+identifier方式、DreamBooth キャプション方式、fine tuning方式があります。
307
+
308
+ ## DreamBooth、class+identifier方式
309
+
310
+ フォルダ名で繰り返し回数を指定します。また `train_data_dir` オプションと `reg_data_dir` オプションを用います。
311
+
312
+ ### step 1. 学習用画像の準備
313
+
314
+ 学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。
315
+
316
+ ```
317
+ <繰り返し回数>_<identifier> <class>
318
+ ```
319
+
320
+ 間の``_``を忘れないでください。
321
+
322
+ たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、「20_sls frog」となります。以下のようになります。
323
+
324
+ ![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png)
325
+
326
+ ### 複数class、複数対象(identifier)の学習
327
+
328
+ 方法は単純で、学習用画像のフォルダ内に ``繰り返し回数_<identifier> <class>`` のフォルダを複数、正則化画像フォルダにも同様に ``繰り返し回数_<class>`` のフォルダを複数、用意してください。
329
+
330
+ たとえば「sls frog」と「cpc rabbit」を同時に学習する場合、以下のようになります。
331
+
332
+ ![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png)
333
+
334
+ classがひとつで対象が複数の場合、正則化画像フォルダはひとつで構いません。たとえば1girlにキャラAとキャラBがいる場合は次のようにします。
335
+
336
+ - train_girls
337
+ - 10_sls 1girl
338
+ - 10_cpc 1girl
339
+ - reg_girls
340
+ - 1_1girl
341
+
342
+ ### step 2. 正則化画像の準備
343
+
344
+ 正則化画像を使う場合の手順です。
345
+
346
+ 正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_<class>`` という名前でディレクトリを作成します。
347
+
348
+ たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。
349
+
350
+ ![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png)
351
+
352
+
353
+ ### step 3. 学習の実行
354
+
355
+ 各学習スクリプトを実行します。 `--train_data_dir` オプションで前述の学習用データのフォルダを(__画像を含むフォルダではなく、その親フォルダ__)、`--reg_data_dir` オプションで正則化画像のフォルダ(__画像を含むフォルダではなく、その親フォルダ__)を指定してください。
356
+
357
+ ## DreamBooth、キャプション方式
358
+
359
+ 学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。
360
+
361
+ ※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。
362
+
363
+ キャプションファイルの拡張子はデフォルトで.captionです。学習スクリプトの `--caption_extension` オプションで変更できます。`--shuffle_caption` オプションで学習時のキャプションについて、カンマ区切りの各部分をシャッフルしながら学習します。
364
+
365
+ ## fine tuning 方式
366
+
367
+ メタデータを作るところまでは設定ファイルを使う場合と同様です。`in_json` オプションでメタデータファイルを指定します。
368
+
369
+ # 学習途中でのサンプル出力
370
+
371
+ 学習中のモデルで試しに画像生成することで学習の進み方を確認できます。学習スクリプトに以下のオプションを指定します。
372
+
373
+ - `--sample_every_n_steps` / `--sample_every_n_epochs`
374
+
375
+ サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
376
+
377
+ - `--sample_prompts`
378
+
379
+ サンプル出力用プロンプトのファイルを指定します。
380
+
381
+ - `--sample_sampler`
382
+
383
+ サンプル出力に使うサンプラーを指定します。
384
+ `'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
385
+
386
+ サンプル出力を行うにはあらかじめプロンプトを記述したテキストファイルを用意しておく必要があります。1行につき1プロンプトで記述します。
387
+
388
+ たとえば以下のようになります。
389
+
390
+ ```txt
391
+ # prompt 1
392
+ masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
393
+
394
+ # prompt 2
395
+ masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
396
+ ```
397
+
398
+ 先頭が `#` の行はコメントになります。`--n` のように 「`--` + 英小文字」で生成画像へのオプションを指定できます。以下が使えます。
399
+
400
+ - `--n` 次のオプションまでをネガティブプロンプトとします。
401
+ - `--w` 生成画像の横幅を指定します。
402
+ - `--h` 生成画像の高さを指定します。
403
+ - `--d` 生成画像のseedを指定します。
404
+ - `--l` 生成画像のCFG scaleを指定します。
405
+ - `--s` 生成時のステップ数を指定します。
406
+
407
+
408
+ # 各スクリプトで共通の、よく使われるオプション
409
+
410
+ スクリプトの更新後、ドキュメントの更新が追い付いていない場合があります。その場合は `--help` オプションで使用できるオプションを確認してください。
411
+
412
+ ## 学習に使うモデル指定
413
+
414
+ - `--v2` / `--v_parameterization`
415
+
416
+ 学習対象モデルとしてHugging Faceのstable-diffusion-2-base、またはそこからのfine tuningモデルを使う場合(推論時に `v2-inference.yaml` を使うように指示されているモデルの場合)は `--v2` オプションを、stable-diffusion-2や768-v-ema.ckpt、およびそれらのfine tuningモデルを使う場合(推論時に `v2-inference-v.yaml` を使うモデルの場合)は `--v2` と `--v_parameterization` の両方のオプションを指定してください。
417
+
418
+ Stable Diffusion 2.0では大きく以下の点が変わっています。
419
+
420
+ 1. 使用するTokenizer
421
+ 2. 使用するText Encoderおよび使用する出力層(2.0は最後から二番目の層を使う)
422
+ 3. Text Encoderの出力次元数(768->1024)
423
+ 4. U-Netの構造(CrossAttentionのhead数など)
424
+ 5. v-parameterization(サンプリング方法が変更されているらしい)
425
+
426
+ このうちbaseでは1~4が、baseのつかない方(768-v)では1~5が採用されています。1~4を有効にするのがv2オプション、5を有効にするのがv_parameterizationオプションです。
427
+
428
+ - `--pretrained_model_name_or_path`
429
+
430
+ 追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
431
+
432
+ ## 学習に関する設定
433
+
434
+ - `--output_dir`
435
+
436
+ 学習後のモデルを保存するフォルダを指定します。
437
+
438
+ - `--output_name`
439
+
440
+ モデルのファイル名を拡張子を除いて指定します。
441
+
442
+ - `--dataset_config`
443
+
444
+ データセットの設定を記述した `.toml` ファイルを指定します。
445
+
446
+ - `--max_train_steps` / `--max_train_epochs`
447
+
448
+ 学習するステップ数やエポック数を指定します。両方指定するとエポック数のほうが優先されます。
449
+
450
+ - `--mixed_precision`
451
+
452
+ 省メモリ化のため mixed precision (混合精度)で学習します。`--mixed_precision="fp16"` のように指定します。mixed precision なし(デフォルト)と比べて精度が低くなる可能性がありますが、学習に必要なGPUメモリ量が大きく減ります。
453
+
454
+ (RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。
455
+
456
+ - `--gradient_checkpointing`
457
+
458
+ 学習時の重みの計算をまとめて行うのではなく少しずつ行うことで、学習に必要なGPUメモリ量を減らします。オンオフは精度には影響しませんが、オンにするとバッチサイズを大きくできるため、そちらでの影響はあります。
459
+
460
+ また一般的にはオンにすると速度は低下しますが、バッチサイズを大きくできるので、トータルでの学習時間はむしろ速くなるかもしれません。
461
+
462
+ - `--xformers` / `--mem_eff_attn`
463
+
464
+ xformersオプションを指定するとxformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(xformersよりも速度は遅くなります)。
465
+
466
+ - `--save_precision`
467
+
468
+ 保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存します(DreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です)。モデルのサイズを削減したい場合などにお使いください。
469
+
470
+ - `--save_every_n_epochs` / `--save_state` / `--resume`
471
+ save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
472
+
473
+ save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます)。保存先はフォルダになります。
474
+
475
+ 学習状態は保存先フォルダに `<output_name>-??????-state`(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
476
+
477
+ 保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ(`output_dir` ではなくその中のstateのフォルダ)を指定してください。
478
+
479
+ なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
480
+
481
+ - `--save_model_as` (DreamBooth, fine tuning のみ)
482
+
483
+ モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
484
+
485
+ `--save_model_as=safetensors` のように指定します。Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
486
+
487
+ - `--clip_skip`
488
+
489
+ `2` を指定すると、Text Encoder (CLIP) の後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。
490
+
491
+ ※SD2.0はデフォルトで後ろから二番目の層を使うため、SD2.0の学習では指定しないでください。
492
+
493
+ 学習対象のモデルがもともと二番目の層を使うように学習されている場合は、2を指定するとよいでしょう。
494
+
495
+ そうではなく最後の層を使用していた場合はモデル全体がそれを前提に学習されています。そのため改めて二番目の層を使用して学習すると、望ましい学習結果を得るにはある程度の枚数の教師データ、長めの学習が必要になるかもしれません。
496
+
497
+ - `--max_token_length`
498
+
499
+ デフォルトは75です。`150` または `225` を指定することでトークン長を拡張して学習できます。長いキャプションで学習する場合に指定してください。
500
+
501
+ ただし学習時のトークン拡張の仕様は Automatic1111 氏のWeb UIとは微妙に異なるため(分割の仕様など)、必要なければ75で学習することをお勧めします。
502
+
503
+ clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
504
+
505
+ - `--persistent_data_loader_workers`
506
+
507
+ Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
508
+
509
+ - `--max_data_loader_n_workers`
510
+
511
+ データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
512
+
513
+ - `--logging_dir` / `--log_prefix`
514
+
515
+ 学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
516
+
517
+ たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
518
+ また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=db_style1_」などとして識別用にお使いください。
519
+
520
+ TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します。
521
+
522
+ ```
523
+ tensorboard --logdir=logs
524
+ ```
525
+
526
+ (tensorboardは環境整備時にあわせてインストールされると思いますが、もし入っていないなら `pip install tensorboard` で入れてください。)
527
+
528
+ その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
529
+
530
+ - `--noise_offset`
531
+
532
+ こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
533
+
534
+ 全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。`0.1` 程度の値を指定するとよいようです。
535
+
536
+ - `--debug_dataset`
537
+
538
+ このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。
539
+
540
+ ※Linux環境(Colabを含む)では画像は表示されません。
541
+
542
+ - `--vae`
543
+
544
+ vaeオプションにStable Diffusionのcheckpoint、VAEのcheckpointファイル、DiffusesのモデルまたはVAE(ともにローカルまたはHugging FaceのモデルIDが指定できます)のいずれかを指定すると、そのVAEを使って学習します(latentsのキャッシュ時または学習中のlatents取得時)。
545
+
546
+ DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
547
+
548
+
549
+ ## オプティマイザ関係
550
+
551
+ - `--optimizer_type`
552
+ --オプティマイザの種類を指定します。以下が指定できます。
553
+ - AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
554
+ - 過去のバージョンのオプション未指定時と同じ
555
+ - AdamW8bit : 引数は同上
556
+ - 過去のバージョンの--use_8bit_adam指定時と同じ
557
+ - Lion : https://github.com/lucidrains/lion-pytorch
558
+ - 過去のバージョンの--use_lion_optimizer指定時と同じ
559
+ - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
560
+ - SGDNesterov8bit : 引数は同上
561
+ - DAdaptation : https://github.com/facebookresearch/dadaptation
562
+ - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
563
+ - 任意のオプティマイザ
564
+
565
+ - `--learning_rate`
566
+
567
+ 学習率を指定します。適切な学習率は学習スクリプトにより異なりますので、それぞれの説明を参照してください。
568
+
569
+ - `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
570
+
571
+ 学習率のスケジューラ関連の指定です。
572
+
573
+ lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。
574
+
575
+ lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。
576
+
577
+ lr_scheduler_num_cycles は cosine with restartsスケジューラでのリスタート回数、lr_scheduler_power は polynomialスケジューラでのpolynomial power です。
578
+
579
+ 詳細については各自お調べください。
580
+
581
+ ### オプティマイザの指定について
582
+
583
+ オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
584
+
585
+ オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
586
+
587
+ 一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。
588
+
589
+ D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
590
+
591
+ AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
592
+
593
+ 自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
594
+
595
+ 学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
596
+
597
+ ### 任意のオプティマイザを使う
598
+
599
+ ``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
600
+
601
+ (内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。)
602
+
603
+
604
+ <!--
605
+ ## 任意サイズの画像での学習 --resolution
606
+ 正方形以外で学習できます。resolutionに「448,640」のように「幅,高さ」で指定してください。幅と高さは64で割り切れる必要があります。学習用画像、正則化画像のサイズを合わせてください。
607
+
608
+ 個人的には縦長の画像を生成することが多いため「448,640」などで学習することもあります。
609
+
610
+ ## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso
611
+ enable_bucketオプションを指定すると有効になります。Stable Diffusionは512x512で学習されていますが、それに加えて256x768や384x640といった解像度でも学習します。
612
+
613
+ このオプションを指定した場合は、学習用画像、正則化画像を特定の解像度に統一する必要はありません。いくつかの解像度(アスペクト比)から最適なものを選び、その解像度で学習します。
614
+ 解像度は64ピクセル単位のため、元画像とアスペクト比が完全に一致しない場合がありますが、その場合は、はみ出した部分がわずかにトリミングされます。
615
+
616
+ 解像度の最小サイズをmin_bucket_resoオプションで、最大サイズをmax_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。
617
+ たとえば最小サイズに384を指定すると、256x1024や320x768などの解像度は使わなくなります。
618
+ 解像度を768x768のように大きくした場合、最大サイズに1280などを指定しても良いかもしれません。
619
+
620
+ なおAspect Ratio Bucketingを有効にするときには、正則化画像についても、学習用画像と似た傾向の様々な解像度を用意した方がいいかもしれません。
621
+
622
+ (ひとつのバッチ内の画像が学習用画像、正則化画像に偏らなくなるため。そこまで大きな影響はないと思いますが……。)
623
+
624
+ ## augmentation --color_aug / --flip_aug
625
+ augmentationは学習時に動的にデータを変化させることで、モデルの性能を上げる手法です。color_augで色合いを微妙に変えつつ、flip_augで左右反転をしつつ、学習します。
626
+
627
+ 動的にデータを変化させるため、cache_latentsオプションと同時に指定できません。
628
+
629
+
630
+ ## 勾配をfp16とした学習(実験的機能) --full_fp16
631
+ full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。
632
+ これによりSD1.xの512x512サイズでは8GB未満、SD2.xの512x512サイズで12GB未満のVRAM使用量で学習できるようです。
633
+
634
+ あらかじめaccelerate configでfp16を指定し、オプションで ``mixed_precision="fp16"`` としてください(bf16では動作しません)。
635
+
636
+ メモリ使用量を最小化するためには、xformers、use_8bit_adam、cache_latents、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
637
+
638
+ (余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
639
+
640
+ PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。
641
+ 学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
642
+
643
+ -->
644
+
645
+ # メタデータファイルの作成
646
+
647
+ ## 教師データの用意
648
+
649
+ 前述のように学習させたい画像データを用意し、任意のフォルダに入れてください。
650
+
651
+ たとえば以下のように画像を格納します。
652
+
653
+ ![教師データフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
654
+
655
+ ## 自動キャプショニング
656
+
657
+ キャプションを使わずタグだけで学習する場合はスキップしてください。
658
+
659
+ また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。
660
+
661
+ ### BLIPによるキャプショニング
662
+
663
+ 最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。
664
+
665
+ finetuneフォルダ内のmake_captions.pyを実行します。
666
+
667
+ ```
668
+ python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
669
+ ```
670
+
671
+ バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
672
+
673
+ ```
674
+ python finetune\make_captions.py --batch_size 8 ..\train_data
675
+ ```
676
+
677
+ キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。
678
+
679
+ batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。
680
+ max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。
681
+ caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。
682
+
683
+ 複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
684
+
685
+ なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで `--seed 42` のように乱数seedを指定してください。
686
+
687
+ その他のオプションは `--help` でヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。
688
+
689
+ デフォルトでは拡張子.captionでキャプションファイルが生成されます。
690
+
691
+ ![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
692
+
693
+ たとえば以下のようなキャプションが付きます。
694
+
695
+ ![キャプションと画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
696
+
697
+ ## DeepDanbooruによるタグ付け
698
+
699
+ danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。
700
+
701
+ タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。
702
+
703
+ ### 環境整備
704
+
705
+ DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。
706
+ またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。
707
+
708
+ 以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。
709
+
710
+ ![DeepDanbooruダウンロードページ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png)
711
+
712
+ 以下のようなこういうディレクトリ構造にしてください
713
+
714
+ ![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
715
+
716
+ Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。
717
+
718
+ ```
719
+ pip install -r requirements.txt
720
+ ```
721
+
722
+ 続いてDeepDanbooru自体をインストールします。
723
+
724
+ ```
725
+ pip install .
726
+ ```
727
+
728
+ 以上でタグ付けの環境整備は完了です。
729
+
730
+ ### タグ付けの実施
731
+ DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。
732
+
733
+ ```
734
+ deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
735
+ ```
736
+
737
+ 教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
738
+
739
+ ```
740
+ deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
741
+ ```
742
+
743
+ タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。
744
+
745
+ 複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
746
+
747
+ 以下のように生成されます。
748
+
749
+ ![DeepDanbooruの生成ファイル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png)
750
+
751
+ こんな感じにタグが付きます(すごい情報量……)。
752
+
753
+ ![DeepDanbooruタグと画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png)
754
+
755
+ ## WD14Taggerによるタグ付け
756
+
757
+ DeepDanbooruの代わりにWD14Taggerを用いる手順です。
758
+
759
+ Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
760
+
761
+ 最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。
762
+
763
+ ### タグ付けの実施
764
+
765
+ スクリプトを実行してタグ付けを行います。
766
+ ```
767
+ python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
768
+ ```
769
+
770
+ 教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
771
+ ```
772
+ python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
773
+ ```
774
+
775
+ 初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。
776
+
777
+ ![ダウンロードされたファイル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png)
778
+
779
+ タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
780
+
781
+ ![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
782
+
783
+ ![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
784
+
785
+ threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
786
+
787
+ batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。
788
+
789
+ model_dirオプションでモデルの保存先フォルダを指定できます。
790
+
791
+ またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。
792
+
793
+ 複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
794
+
795
+ ## キャプションとタグ情報の前処理
796
+
797
+ スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。
798
+
799
+ ### キャプションの前処理
800
+
801
+ キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。
802
+
803
+ ```
804
+ python merge_captions_to_metadata.py --full_apth <教師データフォルダ>
805
+   --in_json <読み込むメタデータファイル名> <メタデータファイル名>
806
+ ```
807
+
808
+ メタデータファイル名は任意の名前です。
809
+ 教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。
810
+
811
+ ```
812
+ python merge_captions_to_metadata.py --full_path train_data meta_cap.json
813
+ ```
814
+
815
+ caption_extensionオプションでキャプションの拡張子を指定できます。
816
+
817
+ 複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
818
+
819
+ ```
820
+ python merge_captions_to_metadata.py --full_path
821
+ train_data1 meta_cap1.json
822
+ python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
823
+ train_data2 meta_cap2.json
824
+ ```
825
+
826
+ in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。
827
+
828
+ __※in_json��プションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
829
+
830
+ ### タグの前処理
831
+
832
+ 同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。
833
+ ```
834
+ python merge_dd_tags_to_metadata.py --full_path <教師データフォルダ>
835
+ --in_json <読み込むメタデータファイル名> <書き込むメタデータファイル名>
836
+ ```
837
+
838
+ 先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。
839
+ ```
840
+ python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
841
+ ```
842
+
843
+ 複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
844
+
845
+ ```
846
+ python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
847
+ train_data1 meta_cap_dd1.json
848
+ python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
849
+ train_data2 meta_cap_dd2.json
850
+ ```
851
+
852
+ in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。
853
+
854
+ __※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
855
+
856
+ ### キャプションとタグのクリーニング
857
+
858
+ ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。
859
+
860
+ ※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。
861
+
862
+ クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。
863
+
864
+ (教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。)
865
+
866
+ ```
867
+ python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名>
868
+ ```
869
+
870
+ --in_jsonは付きませんのでご注意ください。たとえば次のようになります。
871
+
872
+ ```
873
+ python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
874
+ ```
875
+
876
+ 以上でキャプションとタグの前処理は完了です。
877
+
878
+ ## latentsの事前取得
879
+
880
+ ※ このステップは必須ではありません。省略しても学習時にlatentsを取得しながら学習できます。
881
+ また学習時に `random_crop` や `color_aug` などを行う場合にはlatentsの事前取得はできません(画像を毎回変えながら学習するため)。事前取得をしない場合、ここまでのメタデータで学習できます。
882
+
883
+ あらかじめ画像の潜在表現を取得しディスクに保存しておきます。それにより、学習を高速に進めることができます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。
884
+
885
+ 作業フォルダで以下のように入力してください。
886
+ ```
887
+ python prepare_buckets_latents.py --full_path <教師データフォルダ>
888
+ <読み込むメタデータファイル名> <書き込むメタデータファイル名>
889
+ <fine tuningするモデル名またはcheckpoint>
890
+ --batch_size <バッチサイズ>
891
+ --max_resolution <解像度 幅,高さ>
892
+ --mixed_precision <精度>
893
+ ```
894
+
895
+ モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。
896
+
897
+ ```
898
+ python prepare_buckets_latents.py --full_path
899
+ train_data meta_clean.json meta_lat.json model.ckpt
900
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
901
+ ```
902
+
903
+ 教師データフォルダにnumpyのnpz形式でlatentsが保存されます。
904
+
905
+ 解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。
906
+ 解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。
907
+
908
+ --flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二���に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。
909
+
910
+
911
+ (反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fline_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。)
912
+
913
+ バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。
914
+ 解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。
915
+
916
+ ※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。
917
+
918
+ 以下のようにbucketingの結果が表示されます。
919
+
920
+ ![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
921
+
922
+ 複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
923
+ ```
924
+ python prepare_buckets_latents.py --full_path
925
+ train_data1 meta_clean.json meta_lat1.json model.ckpt
926
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
927
+
928
+ python prepare_buckets_latents.py --full_path
929
+ train_data2 meta_lat1.json meta_lat2.json model.ckpt
930
+ --batch_size 4 --max_resolution 512,512 --mixed_precision no
931
+
932
+ ```
933
+ 読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。
934
+
935
+ __※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__
936
+
train_db.py CHANGED
@@ -15,7 +15,11 @@ import diffusers
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
- from library.train_util import DreamBoothDataset
 
 
 
 
19
 
20
 
21
  def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):
33
 
34
  tokenizer = train_util.load_tokenizer(args)
35
 
36
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
37
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
38
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
39
- args.bucket_reso_steps, args.bucket_no_upscale,
40
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
41
-
42
- if args.no_token_padding:
43
- train_dataset.disable_token_padding()
 
 
 
 
 
44
 
45
- # 学習データのdropout率を設定する
46
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
47
 
48
- train_dataset.make_buckets()
 
49
 
50
  if args.debug_dataset:
51
- train_util.debug_dataset(train_dataset)
52
  return
53
 
 
 
 
54
  # acceleratorを準備する
55
  print("prepare accelerator")
56
 
@@ -91,7 +104,7 @@ def train(args):
91
  vae.requires_grad_(False)
92
  vae.eval()
93
  with torch.no_grad():
94
- train_dataset.cache_latents(vae)
95
  vae.to("cpu")
96
  if torch.cuda.is_available():
97
  torch.cuda.empty_cache()
@@ -115,38 +128,18 @@ def train(args):
115
 
116
  # 学習に必要なクラスを準備する
117
  print("prepare optimizer, data loader etc.")
118
-
119
- # 8-bit Adamを使う
120
- if args.use_8bit_adam:
121
- try:
122
- import bitsandbytes as bnb
123
- except ImportError:
124
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
125
- print("use 8-bit Adam optimizer")
126
- optimizer_class = bnb.optim.AdamW8bit
127
- elif args.use_lion_optimizer:
128
- try:
129
- import lion_pytorch
130
- except ImportError:
131
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
132
- print("use Lion optimizer")
133
- optimizer_class = lion_pytorch.Lion
134
- else:
135
- optimizer_class = torch.optim.AdamW
136
-
137
  if train_text_encoder:
138
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
139
  else:
140
  trainable_params = unet.parameters()
141
 
142
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
143
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
144
 
145
  # dataloaderを準備する
146
  # DataLoaderのプロセス数:0はメインプロセスになる
147
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
148
  train_dataloader = torch.utils.data.DataLoader(
149
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
150
 
151
  # 学習ステップ数を計算する
152
  if args.max_train_epochs is not None:
@@ -156,9 +149,10 @@ def train(args):
156
  if args.stop_text_encoder_training is None:
157
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
158
 
159
- # lr schedulerを用意する
160
- lr_scheduler = diffusers.optimization.get_scheduler(
161
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
162
 
163
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
164
  if args.full_fp16:
@@ -195,8 +189,8 @@ def train(args):
195
  # 学習する
196
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
197
  print("running training / 学習開始")
198
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
199
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
200
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
201
  print(f" num epochs / epoch数: {num_train_epochs}")
202
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -217,7 +211,7 @@ def train(args):
217
  loss_total = 0.0
218
  for epoch in range(num_train_epochs):
219
  print(f"epoch {epoch+1}/{num_train_epochs}")
220
- train_dataset.set_current_epoch(epoch + 1)
221
 
222
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
223
  unet.train()
@@ -281,12 +275,12 @@ def train(args):
281
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
282
 
283
  accelerator.backward(loss)
284
- if accelerator.sync_gradients:
285
  if train_text_encoder:
286
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
287
  else:
288
  params_to_clip = unet.parameters()
289
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
@@ -297,9 +291,13 @@ def train(args):
297
  progress_bar.update(1)
298
  global_step += 1
299
 
 
 
300
  current_loss = loss.detach().item()
301
  if args.logging_dir is not None:
302
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
303
  accelerator.log(logs, step=global_step)
304
 
305
  if epoch == 0:
@@ -326,6 +324,8 @@ def train(args):
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
 
 
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
@@ -352,6 +352,8 @@ if __name__ == '__main__':
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
 
 
355
 
356
  parser.add_argument("--no_token_padding", action="store_true",
357
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
 
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
+ import library.config_util as config_util
19
+ from library.config_util import (
20
+ ConfigSanitizer,
21
+ BlueprintGenerator,
22
+ )
23
 
24
 
25
  def collate_fn(examples):
 
37
 
38
  tokenizer = train_util.load_tokenizer(args)
39
 
40
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
41
+ if args.dataset_config is not None:
42
+ print(f"Load dataset config from {args.dataset_config}")
43
+ user_config = config_util.load_user_config(args.dataset_config)
44
+ ignored = ["train_data_dir", "reg_data_dir"]
45
+ if any(getattr(args, attr) is not None for attr in ignored):
46
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
47
+ else:
48
+ user_config = {
49
+ "datasets": [{
50
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
51
+ }]
52
+ }
53
 
54
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
+ if args.no_token_padding:
58
+ train_dataset_group.disable_token_padding()
59
 
60
  if args.debug_dataset:
61
+ train_util.debug_dataset(train_dataset_group)
62
  return
63
 
64
+ if cache_latents:
65
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
+
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
 
 
104
  vae.requires_grad_(False)
105
  vae.eval()
106
  with torch.no_grad():
107
+ train_dataset_group.cache_latents(vae)
108
  vae.to("cpu")
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
 
128
 
129
  # 学習に必要なクラスを準備する
130
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if train_text_encoder:
132
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
133
  else:
134
  trainable_params = unet.parameters()
135
 
136
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
137
 
138
  # dataloaderを準備する
139
  # DataLoaderのプロセス数:0はメインプロセスになる
140
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
141
  train_dataloader = torch.utils.data.DataLoader(
142
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
143
 
144
  # 学習ステップ数を計算する
145
  if args.max_train_epochs is not None:
 
149
  if args.stop_text_encoder_training is None:
150
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
151
 
152
+ # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
153
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
154
+ num_training_steps=args.max_train_steps,
155
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
156
 
157
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
158
  if args.full_fp16:
 
189
  # 学習する
190
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
191
  print("running training / 学習開始")
192
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
193
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
194
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
195
  print(f" num epochs / epoch数: {num_train_epochs}")
196
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
211
  loss_total = 0.0
212
  for epoch in range(num_train_epochs):
213
  print(f"epoch {epoch+1}/{num_train_epochs}")
214
+ train_dataset_group.set_current_epoch(epoch + 1)
215
 
216
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
217
  unet.train()
 
275
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
276
 
277
  accelerator.backward(loss)
278
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
279
  if train_text_encoder:
280
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
281
  else:
282
  params_to_clip = unet.parameters()
283
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
284
 
285
  optimizer.step()
286
  lr_scheduler.step()
 
291
  progress_bar.update(1)
292
  global_step += 1
293
 
294
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
295
+
296
  current_loss = loss.detach().item()
297
  if args.logging_dir is not None:
298
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
299
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
300
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
301
  accelerator.log(logs, step=global_step)
302
 
303
  if epoch == 0:
 
324
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
325
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
326
 
327
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
328
+
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
 
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
355
+ train_util.add_optimizer_arguments(parser)
356
+ config_util.add_config_arguments(parser)
357
 
358
  parser.add_argument("--no_token_padding", action="store_true",
359
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_db_README-ja.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DreamBoothのガイドです。
2
+
3
+ [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
4
+
5
+ # 概要
6
+
7
+ DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。
8
+
9
+ 具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。
10
+
11
+ スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。
12
+
13
+ スクリプトの主な機能は以下の通りです。
14
+
15
+ - 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。
16
+ - xformersによる省メモリ化。
17
+ - 512x512だけではなく任意サイズでの学習。
18
+ - augmentationによる品質の向上。
19
+ - DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。
20
+ - Stable Diffusion形式でのモデルの読み書き。
21
+ - Aspect Ratio Bucketing。
22
+ - Stable Diffusion v2.0対応。
23
+
24
+ # 学習の手順
25
+
26
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
27
+
28
+ ## データの準備
29
+
30
+ [学習データの準備について](./train_README-ja.md) を参照してください。
31
+
32
+ ## 学習の実行
33
+
34
+ スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。
35
+
36
+ ```
37
+ accelerate launch --num_cpu_threads_per_process 1 train_db.py
38
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
39
+ --dataset_config=<データ準備で作成した.tomlファイル>
40
+ --output_dir=<学習したモデルの出力先フォルダ>
41
+ --output_name=<学習したモデル出力時のファイル名>
42
+ --save_model_as=safetensors
43
+ --prior_loss_weight=1.0
44
+ --max_train_steps=1600
45
+ --learning_rate=1e-6
46
+ --optimizer_type="AdamW8bit"
47
+ --xformers
48
+ --mixed_precision="fp16"
49
+ --cache_latents
50
+ --gradient_checkpointing
51
+ ```
52
+
53
+ `num_cpu_threads_per_process` には通常は1を指定するとよいようです。
54
+
55
+ `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
56
+
57
+ `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
58
+
59
+ `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
60
+
61
+ `prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。
62
+
63
+ 学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。
64
+
65
+ 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
66
+
67
+ オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
68
+
69
+ `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
70
+
71
+ 省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。
72
+
73
+ ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。
74
+
75
+ ### よく使われるオプションについて
76
+
77
+ 以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。
78
+
79
+ - Stable Diffusion 2.xまたはそこからの派生モデルを学習する
80
+ - clip skipを2以上を前提としたモデルを学習する
81
+ - 75トークンを超えたキャプションで学習する
82
+
83
+ ### DreamBoothでのステップ数について
84
+
85
+ 当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。
86
+
87
+ 元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。
88
+
89
+ (学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。)
90
+
91
+ ### DreamBoothでのバッチサイズについて
92
+
93
+ モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。
94
+
95
+ ### 学習率について
96
+
97
+ Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。
98
+
99
+ ### 以前の形式のデータセット指定をした場合のコマンドライン
100
+
101
+ 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
102
+
103
+ ```
104
+ accelerate launch --num_cpu_threads_per_process 1 train_db.py
105
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
106
+ --train_data_dir=<学習用データのディレクトリ>
107
+ --reg_data_dir=<正則化画像のディレクトリ>
108
+ --output_dir=<学習したモデルの出力先ディレクトリ>
109
+ --output_name=<学習したモデル出力時のファイル名>
110
+ --prior_loss_weight=1.0
111
+ --resolution=512
112
+ --train_batch_size=1
113
+ --learning_rate=1e-6
114
+ --max_train_steps=1600
115
+ --use_8bit_adam
116
+ --xformers
117
+ --mixed_precision="bf16"
118
+ --cache_latents
119
+ --gradient_checkpointing
120
+ ```
121
+
122
+ ## 学習したモデルで画像生成する
123
+
124
+ 学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。
125
+
126
+ v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。
127
+
128
+ v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。
129
+
130
+ ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png)
131
+
132
+ 各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
133
+
134
+ # DreamBooth特有のその他の主なオプション
135
+
136
+ すべてのオプションについては別文書を参照してください。
137
+
138
+ ## Text Encoderの学習を途中から行わない --stop_text_encoder_training
139
+
140
+ stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。
141
+
142
+ (恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。)
143
+
144
+ ## Tokenizerのパディングをしない --no_token_padding
145
+ no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。
146
+
147
+
148
+ <!--
149
+ bucketing(後述)を利用しかつaugmentation(後述)を使う場合の例は以下のようになります。
150
+
151
+ ```
152
+ accelerate launch --num_cpu_threads_per_process 8 train_db.py
153
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
154
+ --train_data_dir=<学習用データのディレクトリ>
155
+ --reg_data_dir=<正則化画像のディレクトリ>
156
+ --output_dir=<学習したモデルの出力先ディレクトリ>
157
+ --resolution=768,512
158
+ --train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
159
+ --use_8bit_adam --xformers --mixed_precision="bf16"
160
+ --save_every_n_epochs=1 --save_state --save_precision="bf16"
161
+ --logging_dir=logs
162
+ --enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
163
+ --color_aug --flip_aug --gradient_checkpointing --seed 42
164
+ ```
165
+
166
+
167
+ -->
train_network.py CHANGED
@@ -1,8 +1,4 @@
1
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
- from torch.optim import Optimizer
3
- from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
- from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
@@ -15,92 +11,39 @@ import json
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
- import diffusers
19
  from diffusers import DDPMScheduler
20
 
21
  import library.train_util as train_util
22
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
23
 
24
 
25
  def collate_fn(examples):
26
  return examples[0]
27
 
28
 
 
29
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
30
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
31
 
32
  if args.network_train_unet_only:
33
- logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
34
  elif args.network_train_text_encoder_only:
35
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
36
  else:
37
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
38
- logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
39
-
40
- return logs
41
 
 
 
42
 
43
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
44
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
45
- # Which is a newer release of diffusers than currently packaged with sd-scripts
46
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
47
-
48
-
49
- def get_scheduler_fix(
50
- name: Union[str, SchedulerType],
51
- optimizer: Optimizer,
52
- num_warmup_steps: Optional[int] = None,
53
- num_training_steps: Optional[int] = None,
54
- num_cycles: int = 1,
55
- power: float = 1.0,
56
- ):
57
- """
58
- Unified API to get any scheduler from its name.
59
- Args:
60
- name (`str` or `SchedulerType`):
61
- The name of the scheduler to use.
62
- optimizer (`torch.optim.Optimizer`):
63
- The optimizer that will be used during training.
64
- num_warmup_steps (`int`, *optional*):
65
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
66
- optional), the function will raise an error if it's unset and the scheduler type requires it.
67
- num_training_steps (`int``, *optional*):
68
- The number of training steps to do. This is not required by all schedulers (hence the argument being
69
- optional), the function will raise an error if it's unset and the scheduler type requires it.
70
- num_cycles (`int`, *optional*):
71
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
72
- power (`float`, *optional*, defaults to 1.0):
73
- Power factor. See `POLYNOMIAL` scheduler
74
- last_epoch (`int`, *optional*, defaults to -1):
75
- The index of the last epoch when resuming training.
76
- """
77
- name = SchedulerType(name)
78
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
79
- if name == SchedulerType.CONSTANT:
80
- return schedule_func(optimizer)
81
-
82
- # All other schedulers require `num_warmup_steps`
83
- if num_warmup_steps is None:
84
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
85
-
86
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
87
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
88
-
89
- # All other schedulers require `num_training_steps`
90
- if num_training_steps is None:
91
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
92
-
93
- if name == SchedulerType.COSINE_WITH_RESTARTS:
94
- return schedule_func(
95
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
96
- )
97
-
98
- if name == SchedulerType.POLYNOMIAL:
99
- return schedule_func(
100
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
101
- )
102
-
103
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
104
 
105
 
106
  def train(args):
@@ -111,6 +54,7 @@ def train(args):
111
 
112
  cache_latents = args.cache_latents
113
  use_dreambooth_method = args.in_json is None
 
114
 
115
  if args.seed is not None:
116
  set_seed(args.seed)
@@ -118,38 +62,51 @@ def train(args):
118
  tokenizer = train_util.load_tokenizer(args)
119
 
120
  # データセットを準備する
121
- if use_dreambooth_method:
122
- print("Use DreamBooth method.")
123
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
124
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
125
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
126
- args.bucket_reso_steps, args.bucket_no_upscale,
127
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
128
- args.random_crop, args.debug_dataset)
129
  else:
130
- print("Train with captions.")
131
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
132
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
133
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
134
- args.bucket_reso_steps, args.bucket_no_upscale,
135
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
136
- args.dataset_repeats, args.debug_dataset)
137
-
138
- # 学習データのdropout率を設定する
139
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
140
-
141
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
142
 
143
  if args.debug_dataset:
144
- train_util.debug_dataset(train_dataset)
145
  return
146
- if len(train_dataset) == 0:
147
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
148
  return
149
 
 
 
 
 
150
  # acceleratorを準備する
151
  print("prepare accelerator")
152
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
153
 
154
  # mixed precisionに対応した型を用意しておき適宜castする
155
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -161,7 +118,7 @@ def train(args):
161
  if args.lowram:
162
  text_encoder.to("cuda")
163
  unet.to("cuda")
164
-
165
  # モデルに xformers とか memory efficient attention を組み込む
166
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
167
 
@@ -171,13 +128,15 @@ def train(args):
171
  vae.requires_grad_(False)
172
  vae.eval()
173
  with torch.no_grad():
174
- train_dataset.cache_latents(vae)
175
  vae.to("cpu")
176
  if torch.cuda.is_available():
177
  torch.cuda.empty_cache()
178
  gc.collect()
179
 
180
  # prepare network
 
 
181
  print("import network module:", args.network_module)
182
  network_module = importlib.import_module(args.network_module)
183
 
@@ -208,48 +167,25 @@ def train(args):
208
  # 学習に必要なクラスを準備する
209
  print("prepare optimizer, data loader etc.")
210
 
211
- # 8-bit Adamを使う
212
- if args.use_8bit_adam:
213
- try:
214
- import bitsandbytes as bnb
215
- except ImportError:
216
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
217
- print("use 8-bit Adam optimizer")
218
- optimizer_class = bnb.optim.AdamW8bit
219
- elif args.use_lion_optimizer:
220
- try:
221
- import lion_pytorch
222
- except ImportError:
223
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
224
- print("use Lion optimizer")
225
- optimizer_class = lion_pytorch.Lion
226
- else:
227
- optimizer_class = torch.optim.AdamW
228
-
229
- optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
230
-
231
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
232
-
233
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
234
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235
 
236
  # dataloaderを準備する
237
  # DataLoaderのプロセス数:0はメインプロセスになる
238
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
239
  train_dataloader = torch.utils.data.DataLoader(
240
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
241
 
242
  # 学習ステップ数を計算する
243
  if args.max_train_epochs is not None:
244
- args.max_train_steps = args.max_train_epochs * len(train_dataloader)
245
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
246
 
247
  # lr schedulerを用意する
248
- # lr_scheduler = diffusers.optimization.get_scheduler(
249
- lr_scheduler = get_scheduler_fix(
250
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
251
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
252
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
253
 
254
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
255
  if args.full_fp16:
@@ -317,17 +253,21 @@ def train(args):
317
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
318
 
319
  # 学習する
 
320
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
321
- print("running training / 学習開始")
322
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
323
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
324
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
325
- print(f" num epochs / epoch数: {num_train_epochs}")
326
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
327
- print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
328
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
329
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
330
-
 
 
 
331
  metadata = {
332
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
333
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -335,12 +275,10 @@ def train(args):
335
  "ss_learning_rate": args.learning_rate,
336
  "ss_text_encoder_lr": args.text_encoder_lr,
337
  "ss_unet_lr": args.unet_lr,
338
- "ss_num_train_images": train_dataset.num_train_images, # includes repeating
339
- "ss_num_reg_images": train_dataset.num_reg_images,
340
  "ss_num_batches_per_epoch": len(train_dataloader),
341
  "ss_num_epochs": num_train_epochs,
342
- "ss_batch_size_per_device": args.train_batch_size,
343
- "ss_total_batch_size": total_batch_size,
344
  "ss_gradient_checkpointing": args.gradient_checkpointing,
345
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
346
  "ss_max_train_steps": args.max_train_steps,
@@ -352,33 +290,156 @@ def train(args):
352
  "ss_mixed_precision": args.mixed_precision,
353
  "ss_full_fp16": bool(args.full_fp16),
354
  "ss_v2": bool(args.v2),
355
- "ss_resolution": args.resolution,
356
  "ss_clip_skip": args.clip_skip,
357
  "ss_max_token_length": args.max_token_length,
358
- "ss_color_aug": bool(args.color_aug),
359
- "ss_flip_aug": bool(args.flip_aug),
360
- "ss_random_crop": bool(args.random_crop),
361
- "ss_shuffle_caption": bool(args.shuffle_caption),
362
  "ss_cache_latents": bool(args.cache_latents),
363
- "ss_enable_bucket": bool(train_dataset.enable_bucket),
364
- "ss_min_bucket_reso": train_dataset.min_bucket_reso,
365
- "ss_max_bucket_reso": train_dataset.max_bucket_reso,
366
  "ss_seed": args.seed,
367
- "ss_keep_tokens": args.keep_tokens,
368
  "ss_noise_offset": args.noise_offset,
369
- "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
370
- "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
371
- "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
372
- "ss_bucket_info": json.dumps(train_dataset.bucket_info),
373
  "ss_training_comment": args.training_comment, # will not be updated after training
374
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
375
- "ss_optimizer": optimizer_name
 
 
 
 
 
 
376
  }
377
 
378
- # uncomment if another network is added
379
- # for key, value in net_kwargs.items():
380
- # metadata["ss_arg_" + key] = value
381
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  if args.pretrained_model_name_or_path is not None:
383
  sd_model_name = args.pretrained_model_name_or_path
384
  if os.path.exists(sd_model_name):
@@ -397,6 +458,13 @@ def train(args):
397
 
398
  metadata = {k: str(v) for k, v in metadata.items()}
399
 
 
 
 
 
 
 
 
400
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
401
  global_step = 0
402
 
@@ -409,8 +477,9 @@ def train(args):
409
  loss_list = []
410
  loss_total = 0.0
411
  for epoch in range(num_train_epochs):
412
- print(f"epoch {epoch+1}/{num_train_epochs}")
413
- train_dataset.set_current_epoch(epoch + 1)
 
414
 
415
  metadata["ss_epoch"] = str(epoch+1)
416
 
@@ -447,7 +516,7 @@ def train(args):
447
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
448
 
449
  # Predict the noise residual
450
- with autocast():
451
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
452
 
453
  if args.v_parameterization:
@@ -465,9 +534,9 @@ def train(args):
465
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
466
 
467
  accelerator.backward(loss)
468
- if accelerator.sync_gradients:
469
  params_to_clip = network.get_trainable_params()
470
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
471
 
472
  optimizer.step()
473
  lr_scheduler.step()
@@ -478,6 +547,8 @@ def train(args):
478
  progress_bar.update(1)
479
  global_step += 1
480
 
 
 
481
  current_loss = loss.detach().item()
482
  if epoch == 0:
483
  loss_list.append(current_loss)
@@ -508,8 +579,9 @@ def train(args):
508
  def save_func():
509
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
510
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
511
  print(f"saving checkpoint: {ckpt_file}")
512
- unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
513
 
514
  def remove_old_func(old_epoch_no):
515
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -518,15 +590,18 @@ def train(args):
518
  print(f"removing old checkpoint: {old_ckpt_file}")
519
  os.remove(old_ckpt_file)
520
 
521
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
522
- if saving and args.save_state:
523
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
524
 
525
  # end of epoch
526
 
527
  metadata["ss_epoch"] = str(num_train_epochs)
 
528
 
529
- is_main_process = accelerator.is_main_process
530
  if is_main_process:
531
  network = unwrap_model(network)
532
 
@@ -545,7 +620,7 @@ def train(args):
545
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
546
 
547
  print(f"save trained model to {ckpt_file}")
548
- network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
549
  print("model saved.")
550
 
551
 
@@ -555,6 +630,8 @@ if __name__ == '__main__':
555
  train_util.add_sd_models_arguments(parser)
556
  train_util.add_dataset_arguments(parser, True, True, True)
557
  train_util.add_training_arguments(parser, True)
 
 
558
 
559
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
560
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -562,10 +639,6 @@ if __name__ == '__main__':
562
 
563
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
564
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
565
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
566
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
567
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
568
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
569
 
570
  parser.add_argument("--network_weights", type=str, default=None,
571
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
 
 
 
1
  from torch.nn.parallel import DistributedDataParallel as DDP
 
2
  import importlib
3
  import argparse
4
  import gc
 
11
  from tqdm import tqdm
12
  import torch
13
  from accelerate.utils import set_seed
 
14
  from diffusers import DDPMScheduler
15
 
16
  import library.train_util as train_util
17
+ from library.train_util import (
18
+ DreamBoothDataset,
19
+ )
20
+ import library.config_util as config_util
21
+ from library.config_util import (
22
+ ConfigSanitizer,
23
+ BlueprintGenerator,
24
+ )
25
 
26
 
27
  def collate_fn(examples):
28
  return examples[0]
29
 
30
 
31
+ # TODO 他のスクリプトと共通化する
32
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
33
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
34
 
35
  if args.network_train_unet_only:
36
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
37
  elif args.network_train_text_encoder_only:
38
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
39
  else:
40
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
41
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
 
 
42
 
43
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
44
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
45
 
46
+ return logs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  def train(args):
 
54
 
55
  cache_latents = args.cache_latents
56
  use_dreambooth_method = args.in_json is None
57
+ use_user_config = args.dataset_config is not None
58
 
59
  if args.seed is not None:
60
  set_seed(args.seed)
 
62
  tokenizer = train_util.load_tokenizer(args)
63
 
64
  # データセットを準備する
65
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
66
+ if use_user_config:
67
+ print(f"Load dataset config from {args.dataset_config}")
68
+ user_config = config_util.load_user_config(args.dataset_config)
69
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
70
+ if any(getattr(args, attr) is not None for attr in ignored):
71
+ print(
72
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
73
  else:
74
+ if use_dreambooth_method:
75
+ print("Use DreamBooth method.")
76
+ user_config = {
77
+ "datasets": [{
78
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
79
+ }]
80
+ }
81
+ else:
82
+ print("Train with captions.")
83
+ user_config = {
84
+ "datasets": [{
85
+ "subsets": [{
86
+ "image_dir": args.train_data_dir,
87
+ "metadata_file": args.in_json,
88
+ }]
89
+ }]
90
+ }
91
+
92
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
93
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
94
 
95
  if args.debug_dataset:
96
+ train_util.debug_dataset(train_dataset_group)
97
  return
98
+ if len(train_dataset_group) == 0:
99
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
100
  return
101
 
102
+ if cache_latents:
103
+ assert train_dataset_group.is_latent_cacheable(
104
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
105
+
106
  # acceleratorを準備する
107
  print("prepare accelerator")
108
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
109
+ is_main_process = accelerator.is_main_process
110
 
111
  # mixed precisionに対応した型を用意しておき適宜castする
112
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
118
  if args.lowram:
119
  text_encoder.to("cuda")
120
  unet.to("cuda")
121
+
122
  # モデルに xformers とか memory efficient attention を組み込む
123
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
124
 
 
128
  vae.requires_grad_(False)
129
  vae.eval()
130
  with torch.no_grad():
131
+ train_dataset_group.cache_latents(vae)
132
  vae.to("cpu")
133
  if torch.cuda.is_available():
134
  torch.cuda.empty_cache()
135
  gc.collect()
136
 
137
  # prepare network
138
+ import sys
139
+ sys.path.append(os.path.dirname(__file__))
140
  print("import network module:", args.network_module)
141
  network_module = importlib.import_module(args.network_module)
142
 
 
167
  # 学習に必要なクラスを準備する
168
  print("prepare optimizer, data loader etc.")
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
171
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
172
 
173
  # dataloaderを準備する
174
  # DataLoaderのプロセス数:0はメインプロセスになる
175
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
176
  train_dataloader = torch.utils.data.DataLoader(
177
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
178
 
179
  # 学習ステップ数を計算する
180
  if args.max_train_epochs is not None:
181
+ args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
182
+ if is_main_process:
183
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
187
+ num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
188
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
 
189
 
190
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
191
  if args.full_fp16:
 
253
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
254
 
255
  # 学習する
256
+ # TODO: find a way to handle total batch size when there are multiple datasets
257
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
258
+
259
+ if is_main_process:
260
+ print("running training / 学習開始")
261
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
262
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
263
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
264
+ print(f" num epochs / epoch数: {num_train_epochs}")
265
+ print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
266
+ # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
267
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
268
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
269
+
270
+ # TODO refactor metadata creation and move to util
271
  metadata = {
272
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
273
  "ss_training_started_at": training_started_at, # unix timestamp
 
275
  "ss_learning_rate": args.learning_rate,
276
  "ss_text_encoder_lr": args.text_encoder_lr,
277
  "ss_unet_lr": args.unet_lr,
278
+ "ss_num_train_images": train_dataset_group.num_train_images,
279
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
280
  "ss_num_batches_per_epoch": len(train_dataloader),
281
  "ss_num_epochs": num_train_epochs,
 
 
282
  "ss_gradient_checkpointing": args.gradient_checkpointing,
283
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
284
  "ss_max_train_steps": args.max_train_steps,
 
290
  "ss_mixed_precision": args.mixed_precision,
291
  "ss_full_fp16": bool(args.full_fp16),
292
  "ss_v2": bool(args.v2),
 
293
  "ss_clip_skip": args.clip_skip,
294
  "ss_max_token_length": args.max_token_length,
 
 
 
 
295
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
296
  "ss_seed": args.seed,
297
+ "ss_lowram": args.lowram,
298
  "ss_noise_offset": args.noise_offset,
 
 
 
 
299
  "ss_training_comment": args.training_comment, # will not be updated after training
300
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
301
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
302
+ "ss_max_grad_norm": args.max_grad_norm,
303
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
304
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
305
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
306
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
307
+ "ss_prior_loss_weight": args.prior_loss_weight,
308
  }
309
 
310
+ if use_user_config:
311
+ # save metadata of multiple datasets
312
+ # NOTE: pack "ss_datasets" value as json one time
313
+ # or should also pack nested collections as json?
314
+ datasets_metadata = []
315
+ tag_frequency = {} # merge tag frequency for metadata editor
316
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
317
+
318
+ for dataset in train_dataset_group.datasets:
319
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
320
+ dataset_metadata = {
321
+ "is_dreambooth": is_dreambooth_dataset,
322
+ "batch_size_per_device": dataset.batch_size,
323
+ "num_train_images": dataset.num_train_images, # includes repeating
324
+ "num_reg_images": dataset.num_reg_images,
325
+ "resolution": (dataset.width, dataset.height),
326
+ "enable_bucket": bool(dataset.enable_bucket),
327
+ "min_bucket_reso": dataset.min_bucket_reso,
328
+ "max_bucket_reso": dataset.max_bucket_reso,
329
+ "tag_frequency": dataset.tag_frequency,
330
+ "bucket_info": dataset.bucket_info,
331
+ }
332
+
333
+ subsets_metadata = []
334
+ for subset in dataset.subsets:
335
+ subset_metadata = {
336
+ "img_count": subset.img_count,
337
+ "num_repeats": subset.num_repeats,
338
+ "color_aug": bool(subset.color_aug),
339
+ "flip_aug": bool(subset.flip_aug),
340
+ "random_crop": bool(subset.random_crop),
341
+ "shuffle_caption": bool(subset.shuffle_caption),
342
+ "keep_tokens": subset.keep_tokens,
343
+ }
344
+
345
+ image_dir_or_metadata_file = None
346
+ if subset.image_dir:
347
+ image_dir = os.path.basename(subset.image_dir)
348
+ subset_metadata["image_dir"] = image_dir
349
+ image_dir_or_metadata_file = image_dir
350
+
351
+ if is_dreambooth_dataset:
352
+ subset_metadata["class_tokens"] = subset.class_tokens
353
+ subset_metadata["is_reg"] = subset.is_reg
354
+ if subset.is_reg:
355
+ image_dir_or_metadata_file = None # not merging reg dataset
356
+ else:
357
+ metadata_file = os.path.basename(subset.metadata_file)
358
+ subset_metadata["metadata_file"] = metadata_file
359
+ image_dir_or_metadata_file = metadata_file # may overwrite
360
+
361
+ subsets_metadata.append(subset_metadata)
362
+
363
+ # merge dataset dir: not reg subset only
364
+ # TODO update additional-network extension to show detailed dataset config from metadata
365
+ if image_dir_or_metadata_file is not None:
366
+ # datasets may have a certain dir multiple times
367
+ v = image_dir_or_metadata_file
368
+ i = 2
369
+ while v in dataset_dirs_info:
370
+ v = image_dir_or_metadata_file + f" ({i})"
371
+ i += 1
372
+ image_dir_or_metadata_file = v
373
+
374
+ dataset_dirs_info[image_dir_or_metadata_file] = {
375
+ "n_repeats": subset.num_repeats,
376
+ "img_count": subset.img_count
377
+ }
378
+
379
+ dataset_metadata["subsets"] = subsets_metadata
380
+ datasets_metadata.append(dataset_metadata)
381
+
382
+ # merge tag frequency:
383
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
384
+ # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
385
+ # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
386
+ # なので、ここで複数datasetの回数を合算してもあまり意味はない
387
+ if ds_dir_name in tag_frequency:
388
+ continue
389
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
390
+
391
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
392
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
393
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
394
+ else:
395
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
396
+ assert len(
397
+ train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
398
+
399
+ dataset = train_dataset_group.datasets[0]
400
+
401
+ dataset_dirs_info = {}
402
+ reg_dataset_dirs_info = {}
403
+ if use_dreambooth_method:
404
+ for subset in dataset.subsets:
405
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
406
+ info[os.path.basename(subset.image_dir)] = {
407
+ "n_repeats": subset.num_repeats,
408
+ "img_count": subset.img_count
409
+ }
410
+ else:
411
+ for subset in dataset.subsets:
412
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
413
+ "n_repeats": subset.num_repeats,
414
+ "img_count": subset.img_count
415
+ }
416
+
417
+ metadata.update({
418
+ "ss_batch_size_per_device": args.train_batch_size,
419
+ "ss_total_batch_size": total_batch_size,
420
+ "ss_resolution": args.resolution,
421
+ "ss_color_aug": bool(args.color_aug),
422
+ "ss_flip_aug": bool(args.flip_aug),
423
+ "ss_random_crop": bool(args.random_crop),
424
+ "ss_shuffle_caption": bool(args.shuffle_caption),
425
+ "ss_enable_bucket": bool(dataset.enable_bucket),
426
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
427
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
428
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
429
+ "ss_keep_tokens": args.keep_tokens,
430
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
431
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
432
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
433
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
434
+ })
435
+
436
+ # add extra args
437
+ if args.network_args:
438
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
439
+ # for key, value in net_kwargs.items():
440
+ # metadata["ss_arg_" + key] = value
441
+
442
+ # model name and hash
443
  if args.pretrained_model_name_or_path is not None:
444
  sd_model_name = args.pretrained_model_name_or_path
445
  if os.path.exists(sd_model_name):
 
458
 
459
  metadata = {k: str(v) for k, v in metadata.items()}
460
 
461
+ # make minimum metadata for filtering
462
+ minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
463
+ minimum_metadata = {}
464
+ for key in minimum_keys:
465
+ if key in metadata:
466
+ minimum_metadata[key] = metadata[key]
467
+
468
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
469
  global_step = 0
470
 
 
477
  loss_list = []
478
  loss_total = 0.0
479
  for epoch in range(num_train_epochs):
480
+ if is_main_process:
481
+ print(f"epoch {epoch+1}/{num_train_epochs}")
482
+ train_dataset_group.set_current_epoch(epoch + 1)
483
 
484
  metadata["ss_epoch"] = str(epoch+1)
485
 
 
516
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
517
 
518
  # Predict the noise residual
519
+ with accelerator.autocast():
520
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
521
 
522
  if args.v_parameterization:
 
534
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
535
 
536
  accelerator.backward(loss)
537
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
538
  params_to_clip = network.get_trainable_params()
539
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
540
 
541
  optimizer.step()
542
  lr_scheduler.step()
 
547
  progress_bar.update(1)
548
  global_step += 1
549
 
550
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
551
+
552
  current_loss = loss.detach().item()
553
  if epoch == 0:
554
  loss_list.append(current_loss)
 
579
  def save_func():
580
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
581
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
582
+ metadata["ss_training_finished_at"] = str(time.time())
583
  print(f"saving checkpoint: {ckpt_file}")
584
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
585
 
586
  def remove_old_func(old_epoch_no):
587
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
590
  print(f"removing old checkpoint: {old_ckpt_file}")
591
  os.remove(old_ckpt_file)
592
 
593
+ if is_main_process:
594
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
595
+ if saving and args.save_state:
596
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
597
+
598
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
599
 
600
  # end of epoch
601
 
602
  metadata["ss_epoch"] = str(num_train_epochs)
603
+ metadata["ss_training_finished_at"] = str(time.time())
604
 
 
605
  if is_main_process:
606
  network = unwrap_model(network)
607
 
 
620
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
621
 
622
  print(f"save trained model to {ckpt_file}")
623
+ network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
624
  print("model saved.")
625
 
626
 
 
630
  train_util.add_sd_models_arguments(parser)
631
  train_util.add_dataset_arguments(parser, True, True, True)
632
  train_util.add_training_arguments(parser, True)
633
+ train_util.add_optimizer_arguments(parser)
634
+ config_util.add_config_arguments(parser)
635
 
636
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
637
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
639
 
640
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
641
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
642
 
643
  parser.add_argument("--network_weights", type=str, default=None,
644
  help="pretrained weights for network / 学習するネットワークの初期重み")
train_network_README-ja.md ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRAの学習について
2
+
3
+ [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
4
+
5
+ [cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
6
+
7
+ 通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
8
+
9
+ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
10
+
11
+ 8GB VRAMでもぎりぎり動作するようです。
12
+
13
+ [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
14
+
15
+ ## 学習したモデルに関する注意
16
+
17
+ cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
18
+
19
+ WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
20
+
21
+ # 学習の手順
22
+
23
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
24
+
25
+ ## データの準備
26
+
27
+ [学習データの準備について](./train_README-ja.md) を参照してください。
28
+
29
+
30
+ ## 学習の実行
31
+
32
+ `train_network.py`を用います。
33
+
34
+ `train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
35
+
36
+ なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
37
+
38
+ 以下はコマンドラインの例です。
39
+
40
+ ```
41
+ accelerate launch --num_cpu_threads_per_process 1 train_network.py
42
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
43
+ --dataset_config=<データ準備で作成した.tomlファイル>
44
+ --output_dir=<学習したモデルの出力先フォルダ>
45
+ --output_name=<学習したモデル出力時のファイル名>
46
+ --save_model_as=safetensors
47
+ --prior_loss_weight=1.0
48
+ --max_train_steps=400
49
+ --learning_rate=1e-4
50
+ --optimizer_type="AdamW8bit"
51
+ --xformers
52
+ --mixed_precision="fp16"
53
+ --cache_latents
54
+ --gradient_checkpointing
55
+ --save_every_n_epochs=1
56
+ --network_module=networks.lora
57
+ ```
58
+
59
+ `--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
60
+
61
+ その他、以下のオプションが指定できます。
62
+
63
+ * `--network_dim`
64
+ * LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
65
+ * `--network_alpha`
66
+ * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
67
+ * `--persistent_data_loader_workers`
68
+ * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
69
+ * `--max_data_loader_n_workers`
70
+ * データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
71
+ * `--network_weights`
72
+ * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
73
+ * `--network_train_unet_only`
74
+ * U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
75
+ * `--network_train_text_encoder_only`
76
+ * Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
77
+ * `--unet_lr`
78
+ * U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
79
+ * `--text_encoder_lr`
80
+ * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
81
+ * `--network_args`
82
+ * 複数の引数を指定できます。後述します。
83
+
84
+ `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
85
+
86
+ ## LoRA を Conv2d に拡大して適用する
87
+
88
+ 通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
89
+
90
+ `--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
91
+
92
+ ```
93
+ --network_args "conv_dim=1" "conv_alpha=1"
94
+ ```
95
+
96
+ 以下のように alpha 省略時は1になります。
97
+
98
+ ```
99
+ --network_args "conv_dim=1"
100
+ ```
101
+
102
+ ## マージスクリプトについて
103
+
104
+ merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
105
+
106
+ ### Stable DiffusionのモデルにLoRAのモデルをマージする
107
+
108
+ マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
109
+
110
+ ```
111
+ python networks\merge_lora.py --sd_model ..\model\model.ckpt
112
+ --save_to ..\lora_train1\model-char1-merged.safetensors
113
+ --models ..\lora_train1\last.safetensors --ratios 0.8
114
+ ```
115
+
116
+ Stable Diffusion v2.xのモデルで学習し、それにマージする場合は、--v2オプションを指定してください。
117
+
118
+ --sd_modelオプションにマージの元となるStable Diffusionのモデルファイルを指定します(.ckptまたは.safetensorsのみ対応で、Diffusersは今のところ対応していません)。
119
+
120
+ --save_toオプションにマージ後のモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
121
+
122
+ --modelsに学習したLoRAのモデルファイルを指定します。複数指定も可能で、その時は順にマージします。
123
+
124
+ --ratiosにそれぞれのモデルの適用率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。例えば過学習に近いような場合は、適用率を下げるとマシになるかもしれません。モデルの数と同じだけ指定してください。
125
+
126
+ 複数指定時は以下のようになります。
127
+
128
+ ```
129
+ python networks\merge_lora.py --sd_model ..\model\model.ckpt
130
+ --save_to ..\lora_train1\model-char1-merged.safetensors
131
+ --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
132
+ ```
133
+
134
+ ### 複数のLoRAのモデルをマージする
135
+
136
+ 複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
137
+
138
+ たとえば以下のようなコマンドラインになります。
139
+
140
+ ```
141
+ python networks\merge_lora.py
142
+ --save_to ..\lora_train1\model-char1-style1-merged.safetensors
143
+ --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
144
+ ```
145
+
146
+ --sd_modelオプションは指定不要です。
147
+
148
+ --save_toオプションにマージ後のLoRAモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
149
+
150
+ --modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
151
+
152
+ --ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
153
+
154
+ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
155
+
156
+
157
+ ### その他のオプション
158
+
159
+ * precision
160
+ * マージ計算時の精度をfloat、fp16、bf16から指定できます。省略時は精度を確保するためfloatになります。メモリ使用量を減らしたい場合はfp16/bf16を指定してください。
161
+ * save_precision
162
+ * モ��ル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
163
+
164
+
165
+ ## 複数のrankが異なるLoRAのモデルをマージする
166
+
167
+ 複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。
168
+
169
+ ```
170
+ python networks\svd_merge_lora.py
171
+ --save_to ..\lora_train1\model-char1-style1-merged.safetensors
172
+ --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
173
+ --ratios 0.6 0.4 --new_rank 32 --device cuda
174
+ ```
175
+
176
+ `merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。
177
+
178
+ - `--new_rank`
179
+ - 作成するLoRAのrankを指定します。
180
+ - `--new_conv_rank`
181
+ - 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。
182
+ - `--device`
183
+ - `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。
184
+
185
+ ## 当リポジトリ内の画像生成スクリプトで生成する
186
+
187
+ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
188
+
189
+ --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
190
+
191
+ ## 二つのモデルの差分からLoRAモデルを作成する
192
+
193
+ [こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
194
+
195
+ 二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。
196
+
197
+ ### スクリプトの実行方法
198
+
199
+ 以下のように指定してください。
200
+ ```
201
+ python networks\extract_lora_from_models.py --model_org base-model.ckpt
202
+ --model_tuned fine-tuned-model.ckpt
203
+ --save_to lora-weights.safetensors --dim 4
204
+ ```
205
+
206
+ --model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。
207
+
208
+ --model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。
209
+
210
+ --save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。
211
+
212
+ 生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。
213
+
214
+ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。
215
+
216
+ ### その他のオプション
217
+
218
+ - `--v2`
219
+ - v2.xのStable Diffusionモデルを使う場合に指定してください。
220
+ - `--device`
221
+ - ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
222
+ - `--save_precision`
223
+ - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
224
+ - `--conv_dim`
225
+ - 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。
226
+
227
+ ## 画像リサイズスクリプト
228
+
229
+ (のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
230
+
231
+ Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
232
+
233
+ ### スクリプトの実行方法
234
+
235
+ 以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
236
+
237
+ ```
238
+ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
239
+ --copy_associated_files 元画像フォルダ 変換先フォルダ
240
+ ```
241
+
242
+ 元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
243
+
244
+ ``--max_resolution`` オプションにリサイズ���のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
245
+
246
+ ``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式(quality=100)で保存されます。
247
+
248
+ ``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
249
+
250
+
251
+ ### その他のオプション
252
+
253
+ - divisible_by
254
+ - リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
255
+ - interpolation
256
+ - 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
257
+
258
+
259
+ ## 追加情報
260
+
261
+ ### cloneofsimo氏のリポジトリとの違い
262
+
263
+ 2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
264
+
265
+ またモジュール入れ替え機構は全く異なります。
266
+
267
+ ### 将来拡張について
268
+
269
+ LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
train_network_opt.py CHANGED
@@ -1,8 +1,5 @@
1
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
- from torch.optim import Optimizer
3
  from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
- from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
@@ -15,138 +12,49 @@ import json
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
- import diffusers
19
  from diffusers import DDPMScheduler
20
- print("**********************************")
21
- #先に
22
- #pip install torch_optimizer
23
- #が必要
24
- try:
25
- import torch_optimizer as optim
26
- except:
27
- print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
28
- try:
29
- import adastand
30
- except:
31
- print("※Adastandが使えません")
32
-
33
- from transformers.optimization import Adafactor, AdafactorSchedule
34
- print("**********************************")
35
  ##### バケット拡張のためのモジュール
36
  import append_module
37
  ######
38
  import library.train_util as train_util
39
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
40
 
41
 
42
  def collate_fn(examples):
43
  return examples[0]
44
 
45
 
46
- def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
 
47
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
48
-
49
- if args.network_train_unet_only:
50
- logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
51
- elif args.network_train_text_encoder_only:
52
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
 
 
 
53
  else:
54
  last_lrs = lr_scheduler.get_last_lr()
55
- if len(last_lrs) == 2:
56
- logs["lr/textencoder"] = float(last_lrs[0])
57
- logs["lr/unet"] = float(last_lrs[-1]) # may be same to textencoder
58
- else:
59
- if len(last_lrs) == 4:
60
- logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
61
- elif len(last_lrs) == 8:
62
- logs_names = ["textencoder", "unet_midblock"]
63
- for i in range(3):
64
- logs_names.append(f"unet_down_blocks_{i}")
65
- logs_names.append(f"unet_up_blocks_{i+1}")
66
- else:
67
- logs_names = []
68
- for i in range(12):
69
- logs_names.append(f"text_model_encoder_layers_{i}_")
70
- logs_names.append("unet_midblock")
71
- for i in range(3):
72
- logs_names.append(f"unet_down_blocks_{i}")
73
- logs_names.append(f"unet_up_blocks_{i+1}")
74
-
75
- for last_lr, logs_name in zip(last_lrs, logs_names):
76
- logs[f"lr/{logs_name}"] = float(last_lr)
77
 
78
  return logs
79
 
80
 
81
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
82
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
83
- # Which is a newer release of diffusers than currently packaged with sd-scripts
84
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
85
-
86
-
87
- def get_scheduler_fix(
88
- name: Union[str, SchedulerType],
89
- optimizer: Optimizer,
90
- num_warmup_steps: Optional[int] = None,
91
- num_training_steps: Optional[int] = None,
92
- num_cycles: float = 1.,
93
- power: float = 1.0,
94
- ):
95
- """
96
- Unified API to get any scheduler from its name.
97
- Args:
98
- name (`str` or `SchedulerType`):
99
- The name of the scheduler to use.
100
- optimizer (`torch.optim.Optimizer`):
101
- The optimizer that will be used during training.
102
- num_warmup_steps (`int`, *optional*):
103
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
104
- optional), the function will raise an error if it's unset and the scheduler type requires it.
105
- num_training_steps (`int``, *optional*):
106
- The number of training steps to do. This is not required by all schedulers (hence the argument being
107
- optional), the function will raise an error if it's unset and the scheduler type requires it.
108
- num_cycles (`int`, *optional*):
109
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
110
- power (`float`, *optional*, defaults to 1.0):
111
- Power factor. See `POLYNOMIAL` scheduler
112
- last_epoch (`int`, *optional*, defaults to -1):
113
- The index of the last epoch when resuming training.
114
- """
115
- name = SchedulerType(name)
116
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
117
- if name == SchedulerType.CONSTANT:
118
- return schedule_func(optimizer)
119
-
120
- # All other schedulers require `num_warmup_steps`
121
- if num_warmup_steps is None:
122
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
123
-
124
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
125
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
126
-
127
- # All other schedulers require `num_training_steps`
128
- if num_training_steps is None:
129
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
130
-
131
- if name == SchedulerType.COSINE:
132
- print(f"{name} num_cycles: {num_cycles}")
133
- return schedule_func(
134
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
135
- )
136
- if name == SchedulerType.COSINE_WITH_RESTARTS:
137
- print(f"{name} num_cycles: {int(num_cycles)}")
138
- return schedule_func(
139
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
140
- )
141
-
142
- if name == SchedulerType.POLYNOMIAL:
143
- return schedule_func(
144
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
145
- )
146
-
147
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
148
-
149
-
150
  def train(args):
151
  session_id = random.randint(0, 2**32)
152
  training_started_at = time.time()
@@ -155,6 +63,7 @@ def train(args):
155
 
156
  cache_latents = args.cache_latents
157
  use_dreambooth_method = args.in_json is None
 
158
 
159
  if args.seed is not None:
160
  set_seed(args.seed)
@@ -162,52 +71,72 @@ def train(args):
162
  tokenizer = train_util.load_tokenizer(args)
163
 
164
  # データセットを準備する
165
- if use_dreambooth_method:
166
- if args.min_resolution:
167
- args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
168
- if len(args.min_resolution) == 1:
169
- args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
170
-
171
- print("Use DreamBooth method.")
172
- train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
173
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
174
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
175
- args.bucket_reso_steps, args.bucket_no_upscale,
176
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
177
- args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
178
  else:
179
- print("Train with captions.")
180
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
181
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
182
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
183
- args.bucket_reso_steps, args.bucket_no_upscale,
184
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
185
- args.dataset_repeats, args.debug_dataset)
186
-
187
- # 学習データのdropout率を設定する
188
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
189
-
190
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  if args.debug_dataset:
193
- train_util.debug_dataset(train_dataset)
194
  return
195
- if len(train_dataset) == 0:
196
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
197
  return
198
 
 
 
 
 
199
  # acceleratorを準備する
200
  print("prepare accelerator")
201
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
202
 
203
  # mixed precisionに対応した型を用意しておき適宜castする
204
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
205
 
206
  # モデルを読み込む
207
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
208
- # unnecessary, but work on low-ram device
209
- text_encoder.to("cuda")
210
- unet.to("cuda")
 
 
 
211
  # モデルに xformers とか memory efficient attention を組み込む
212
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
213
 
@@ -217,13 +146,15 @@ def train(args):
217
  vae.requires_grad_(False)
218
  vae.eval()
219
  with torch.no_grad():
220
- train_dataset.cache_latents(vae)
221
  vae.to("cpu")
222
  if torch.cuda.is_available():
223
  torch.cuda.empty_cache()
224
  gc.collect()
225
 
226
  # prepare network
 
 
227
  print("import network module:", args.network_module)
228
  network_module = importlib.import_module(args.network_module)
229
 
@@ -253,188 +184,65 @@ def train(args):
253
 
254
  # 学習に必要なクラスを準備する
255
  print("prepare optimizer, data loader etc.")
256
- try:
257
- print(f"torch_optimzier version is {optim.__version__}")
258
- not_torch_optimizer_flag = False
259
- except:
260
- not_torch_optimizer_flag = True
261
- try:
262
- print(f"adastand version is {adastand.__version__()}")
263
- not_adasatand_optimzier_flag = False
264
- except:
265
- not_adasatand_optimzier_flag = True
266
-
267
- # 8-bit Adamを使う
268
- if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
269
- not_torch_optimizer_flag = False
270
- if args.optimizer=="Adafactor":
271
- not_adasatand_optimzier_flag = False
272
- if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
273
- print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
274
- args.optimizer="AdamW"
275
- if args.use_8bit_adam:
276
- if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
277
- print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
278
- args.use_8bit_adam=False
279
- if args.use_8bit_adam:
280
- try:
281
- import bitsandbytes as bnb
282
- except ImportError:
283
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
284
- print("use 8-bit Adam optimizer")
285
- args.training_comment=f"{args.training_comment} use_8bit_adam=True"
286
- if args.optimizer=="Lamb":
287
- optimizer_class = bnb.optim.LAMB8bit
288
- else:
289
- args.optimizer="AdamW"
290
- optimizer_class = bnb.optim.AdamW8bit
291
- else:
292
- print(f"use {args.optimizer}")
293
- if args.optimizer=="RAdam":
294
- optimizer_class = torch.optim.RAdam
295
- elif args.optimizer=="AdaBound":
296
- optimizer_class = optim.AdaBound
297
- elif args.optimizer=="AdaBelief":
298
- optimizer_class = optim.AdaBelief
299
- elif args.optimizer=="AdamP":
300
- optimizer_class = optim.AdamP
301
- elif args.optimizer=="Adafactor":
302
- optimizer_class = Adafactor
303
- elif args.optimizer=="Adastand":
304
- optimizer_class = adastand.Adastand
305
- elif args.optimizer=="Adastand_belief":
306
- optimizer_class = adastand.Adastand_b
307
- elif args.optimizer=="AggMo":
308
- optimizer_class = optim.AggMo
309
- elif args.optimizer=="Apollo":
310
- optimizer_class = optim.Apollo
311
- elif args.optimizer=="Lamb":
312
- optimizer_class = optim.Lamb
313
- elif args.optimizer=="Ranger":
314
- optimizer_class = optim.Ranger
315
- elif args.optimizer=="RangerVA":
316
- optimizer_class = optim.RangerVA
317
- elif args.optimizer=="Yogi":
318
- optimizer_class = optim.Yogi
319
- elif args.optimizer=="Shampoo":
320
- optimizer_class = optim.Shampoo
321
- elif args.optimizer=="NovoGrad":
322
- optimizer_class = optim.NovoGrad
323
- elif args.optimizer=="QHAdam":
324
- optimizer_class = optim.QHAdam
325
- elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
326
- optimizer_class = optim.DiffGrad
327
- elif args.optimizer=="MADGRAD":
328
- optimizer_class = optim.MADGRAD
329
- else:
330
- optimizer_class = torch.optim.AdamW
331
-
332
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
333
- #optimizerデフォ設定
334
- if args.optimizer_arg==None:
335
- if args.optimizer=="AdaBelief":
336
- args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
337
- elif args.optimizer=="DiffGrad":
338
- args.optimizer_arg = ["eps=1e-16"]
339
- optimizer_arg = {}
340
- lookahed_arg = {"k": 5, "alpha": 0.5}
341
- adafactor_scheduler_arg = {"initial_lr": 0.}
342
- int_args = ["k","n_sma_threshold","warmup"]
343
- str_args = ["transformer","grad_transformer"]
344
- if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
345
- for _opt_arg in args.optimizer_arg:
346
- key, value = _opt_arg.split("=")
347
- if value=="True" or value=="False":
348
- optimizer_arg[key]=bool((value=="True"))
349
- elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
350
- _value = value.split(",")
351
- optimizer_arg[key] = (float(_value[0]),float(_value[1]))
352
- del _value
353
- elif key in int_args:
354
- if "Lookahead" in args.optimizer:
355
- lookahed_arg[key] = int(value)
356
- else:
357
- optimizer_arg[key] = int(value)
358
- elif key in str_args:
359
- optimizer_arg[key] = value
360
- else:
361
- if key=="alpha" and "Lookahead" in args.optimizer:
362
- lookahed_arg[key] = int(value)
363
- elif key=="initial_lr" and args.optimizer == "Adafactor":
364
- adafactor_scheduler_arg[key] = float(value)
365
- else:
366
- optimizer_arg[key] = float(value)
367
- del _opt_arg
368
- AdafactorScheduler_Flag = False
369
- list_of_init_lr = []
370
- if args.optimizer=="Adafactor":
371
- if not "relative_step" in optimizer_arg:
372
- optimizer_arg["relative_step"] = True
373
- if "warmup_init" in optimizer_arg:
374
- if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
375
- print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
376
- optimizer_arg["relative_step"] = True
377
- if optimizer_arg["relative_step"] == True:
378
- AdafactorScheduler_Flag = True
379
- list_of_init_lr = [0.,0.]
380
- if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
381
- if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
382
- #if not "initial_lr" in adafactor_scheduler_arg:
383
- # adafactor_scheduler_arg = args.learning_rate
384
- args.learning_rate = None
385
- args.text_encoder_lr = None
386
- args.unet_lr = None
387
- print(f"optimizer arg: {optimizer_arg}")
388
- print("=-----------------------------------=")
389
- if not AdafactorScheduler_Flag: args.split_lora_networks = False
390
  if args.split_lora_networks:
 
391
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
392
- append_module.replace_prepare_optimizer_params(network)
393
- trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
394
  else:
395
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
396
- _list_of_init_lr = []
397
- print(f"trainable_params_len: {len(trainable_params)}")
398
- if len(_list_of_init_lr)>0:
399
- list_of_init_lr = _list_of_init_lr
400
- print(f"split loras network is {len(list_of_init_lr)}")
401
- if len(list_of_init_lr) > 0:
402
- adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
403
-
404
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
405
-
406
- if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
407
- optimizer = optim.Lookahead(optimizer, **lookahed_arg)
408
- print(f"lookahed_arg: {lookahed_arg}")
409
-
 
 
 
 
 
 
 
 
 
 
410
  # dataloaderを準備する
411
  # DataLoaderのプロセス数:0はメインプロセスになる
412
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
413
  train_dataloader = torch.utils.data.DataLoader(
414
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
415
 
416
  # 学習ステップ数を計算する
417
  if args.max_train_epochs is not None:
418
- args.max_train_steps = args.max_train_epochs * len(train_dataloader)
419
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
420
 
421
  # lr schedulerを用意する
422
- # lr_scheduler = diffusers.optimization.get_scheduler(
423
- if AdafactorScheduler_Flag:
424
- print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
425
- lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
426
- print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
427
- del list_of_init_lr
428
  else:
429
- lr_scheduler = get_scheduler_fix(
430
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
431
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
432
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
433
-
434
  #追加機能の設定をコメントに追記して残す
435
- args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
436
- if AdafactorScheduler_Flag:
437
- args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
 
438
  if args.min_resolution:
439
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
440
 
@@ -504,17 +312,21 @@ def train(args):
504
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
505
 
506
  # 学習する
 
507
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
508
- print("running training / 学習開始")
509
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
510
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
511
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
512
- print(f" num epochs / epoch数: {num_train_epochs}")
513
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
514
- print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
515
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
516
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
517
-
 
 
 
518
  metadata = {
519
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
520
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -522,12 +334,10 @@ def train(args):
522
  "ss_learning_rate": args.learning_rate,
523
  "ss_text_encoder_lr": args.text_encoder_lr,
524
  "ss_unet_lr": args.unet_lr,
525
- "ss_num_train_images": train_dataset.num_train_images, # includes repeating
526
- "ss_num_reg_images": train_dataset.num_reg_images,
527
  "ss_num_batches_per_epoch": len(train_dataloader),
528
  "ss_num_epochs": num_train_epochs,
529
- "ss_batch_size_per_device": args.train_batch_size,
530
- "ss_total_batch_size": total_batch_size,
531
  "ss_gradient_checkpointing": args.gradient_checkpointing,
532
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
533
  "ss_max_train_steps": args.max_train_steps,
@@ -539,32 +349,156 @@ def train(args):
539
  "ss_mixed_precision": args.mixed_precision,
540
  "ss_full_fp16": bool(args.full_fp16),
541
  "ss_v2": bool(args.v2),
542
- "ss_resolution": args.resolution,
543
  "ss_clip_skip": args.clip_skip,
544
  "ss_max_token_length": args.max_token_length,
545
- "ss_color_aug": bool(args.color_aug),
546
- "ss_flip_aug": bool(args.flip_aug),
547
- "ss_random_crop": bool(args.random_crop),
548
- "ss_shuffle_caption": bool(args.shuffle_caption),
549
  "ss_cache_latents": bool(args.cache_latents),
550
- "ss_enable_bucket": bool(train_dataset.enable_bucket),
551
- "ss_min_bucket_reso": train_dataset.min_bucket_reso,
552
- "ss_max_bucket_reso": train_dataset.max_bucket_reso,
553
  "ss_seed": args.seed,
554
- "ss_keep_tokens": args.keep_tokens,
555
  "ss_noise_offset": args.noise_offset,
556
- "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
557
- "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
558
- "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
559
- "ss_bucket_info": json.dumps(train_dataset.bucket_info),
560
  "ss_training_comment": args.training_comment, # will not be updated after training
561
- "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
 
 
 
 
 
 
 
562
  }
563
 
564
- # uncomment if another network is added
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  # for key, value in net_kwargs.items():
566
  # metadata["ss_arg_" + key] = value
567
 
 
568
  if args.pretrained_model_name_or_path is not None:
569
  sd_model_name = args.pretrained_model_name_or_path
570
  if os.path.exists(sd_model_name):
@@ -583,6 +517,13 @@ def train(args):
583
 
584
  metadata = {k: str(v) for k, v in metadata.items()}
585
 
 
 
 
 
 
 
 
586
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
587
  global_step = 0
588
 
@@ -595,8 +536,9 @@ def train(args):
595
  loss_list = []
596
  loss_total = 0.0
597
  for epoch in range(num_train_epochs):
598
- print(f"epoch {epoch+1}/{num_train_epochs}")
599
- train_dataset.set_current_epoch(epoch + 1)
 
600
 
601
  metadata["ss_epoch"] = str(epoch+1)
602
 
@@ -633,7 +575,7 @@ def train(args):
633
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
634
 
635
  # Predict the noise residual
636
- with autocast():
637
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
638
 
639
  if args.v_parameterization:
@@ -651,12 +593,13 @@ def train(args):
651
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
652
 
653
  accelerator.backward(loss)
654
- if accelerator.sync_gradients:
655
  params_to_clip = network.get_trainable_params()
656
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
657
 
658
  optimizer.step()
659
- lr_scheduler.step()
 
660
  optimizer.zero_grad(set_to_none=True)
661
 
662
  # Checks if the accelerator has performed an optimization step behind the scenes
@@ -664,6 +607,8 @@ def train(args):
664
  progress_bar.update(1)
665
  global_step += 1
666
 
 
 
667
  current_loss = loss.detach().item()
668
  if epoch == 0:
669
  loss_list.append(current_loss)
@@ -676,7 +621,7 @@ def train(args):
676
  progress_bar.set_postfix(**logs)
677
 
678
  if args.logging_dir is not None:
679
- logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
680
  accelerator.log(logs, step=global_step)
681
 
682
  if global_step >= args.max_train_steps:
@@ -694,8 +639,9 @@ def train(args):
694
  def save_func():
695
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
696
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
697
  print(f"saving checkpoint: {ckpt_file}")
698
- unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
699
 
700
  def remove_old_func(old_epoch_no):
701
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -704,15 +650,18 @@ def train(args):
704
  print(f"removing old checkpoint: {old_ckpt_file}")
705
  os.remove(old_ckpt_file)
706
 
707
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
708
- if saving and args.save_state:
709
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
710
 
711
  # end of epoch
712
 
713
  metadata["ss_epoch"] = str(num_train_epochs)
 
714
 
715
- is_main_process = accelerator.is_main_process
716
  if is_main_process:
717
  network = unwrap_model(network)
718
 
@@ -731,7 +680,7 @@ def train(args):
731
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
732
 
733
  print(f"save trained model to {ckpt_file}")
734
- network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
735
  print("model saved.")
736
 
737
 
@@ -741,6 +690,8 @@ if __name__ == '__main__':
741
  train_util.add_sd_models_arguments(parser)
742
  train_util.add_dataset_arguments(parser, True, True, True)
743
  train_util.add_training_arguments(parser, True)
 
 
744
 
745
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
746
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -748,10 +699,6 @@ if __name__ == '__main__':
748
 
749
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
750
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
751
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
752
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
753
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
754
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
755
 
756
  parser.add_argument("--network_weights", type=str, default=None,
757
  help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -771,27 +718,30 @@ if __name__ == '__main__':
771
  #Optimizer変更関連のオプション追加
772
  append_module.add_append_arguments(parser)
773
  args = append_module.get_config(parser)
 
 
 
 
 
 
 
 
 
 
 
 
774
 
775
  if args.resolution==args.min_resolution:
776
  args.min_resolution=None
777
 
778
  train(args)
 
779
 
780
- #学習が終わったら現在のargsを保存する
781
- # import yaml
782
- # import datetime
783
- # _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
784
- # if args.output_name==None:
785
- # config_name = f"train_network_config_{_t}.yaml"
786
- # else:
787
- # config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
788
- # print(f"{config_name} に設定を書き出し中...")
789
- # with open(config_name, mode="w") as f:
790
- # yaml.dump(args.__dict__, f, indent=4)
791
- # print("done!")
792
 
793
  '''
794
  optimizer設定メモ
 
 
795
  (optimizer_argから設定できるように変更するためのメモ)
796
 
797
  AdamWのweight_decay初期値は1e-2
@@ -821,6 +771,7 @@ Adafactor
821
  transformerベースのT5学習において最強とかいう噂のoptimizer
822
  huggingfaceのサンプルパラ
823
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
 
824
 
825
  AggMo
826
 
 
 
 
1
  from torch.cuda.amp import autocast
2
  from torch.nn.parallel import DistributedDataParallel as DDP
 
3
  import importlib
4
  import argparse
5
  import gc
 
12
  from tqdm import tqdm
13
  import torch
14
  from accelerate.utils import set_seed
15
+ #import diffusers
16
  from diffusers import DDPMScheduler
17
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  ##### バケット拡張のためのモジュール
19
  import append_module
20
  ######
21
  import library.train_util as train_util
22
+ from library.train_util import (
23
+ DreamBoothDataset,
24
+ )
25
+ import library.config_util as config_util
26
+ from library.config_util import (
27
+ ConfigSanitizer,
28
+ BlueprintGenerator,
29
+ )
30
 
31
 
32
  def collate_fn(examples):
33
  return examples[0]
34
 
35
 
36
+ # TODO 他のスクリプトと共通化する
37
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
38
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
39
+ if not args.split_lora_networks:
40
+ if args.network_train_unet_only:
41
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
42
+ elif args.network_train_text_encoder_only:
43
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
44
+ else:
45
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
46
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
47
  else:
48
  last_lrs = lr_scheduler.get_last_lr()
49
+ for last_lr, t_name in zip(last_lrs, split_names):
50
+ logs[f"lr/{t_name}"] = float(last_lr)
51
+ #D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
52
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
53
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  return logs
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def train(args):
59
  session_id = random.randint(0, 2**32)
60
  training_started_at = time.time()
 
63
 
64
  cache_latents = args.cache_latents
65
  use_dreambooth_method = args.in_json is None
66
+ use_user_config = args.dataset_config is not None
67
 
68
  if args.seed is not None:
69
  set_seed(args.seed)
 
71
  tokenizer = train_util.load_tokenizer(args)
72
 
73
  # データセットを準備する
74
+ if args.min_resolution:
75
+ args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
76
+ if len(args.min_resolution) == 1:
77
+ args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
78
+ blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
 
 
 
 
 
 
 
 
79
  else:
80
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
81
+ if use_user_config:
82
+ print(f"Load dataset config from {args.dataset_config}")
83
+ user_config = config_util.load_user_config(args.dataset_config)
84
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
85
+ if any(getattr(args, attr) is not None for attr in ignored):
86
+ print(
87
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
88
+ else:
89
+ if use_dreambooth_method:
90
+ print("Use DreamBooth method.")
91
+ user_config = {
92
+ "datasets": [{
93
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
94
+ }]
95
+ }
96
+ else:
97
+ print("Train with captions.")
98
+ user_config = {
99
+ "datasets": [{
100
+ "subsets": [{
101
+ "image_dir": args.train_data_dir,
102
+ "metadata_file": args.in_json,
103
+ }]
104
+ }]
105
+ }
106
+
107
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
108
+ if args.min_resolution:
109
+ train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
110
+ else:
111
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
112
 
113
  if args.debug_dataset:
114
+ train_util.debug_dataset(train_dataset_group)
115
  return
116
+ if len(train_dataset_group) == 0:
117
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
118
  return
119
 
120
+ if cache_latents:
121
+ assert train_dataset_group.is_latent_cacheable(
122
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
123
+
124
  # acceleratorを準備する
125
  print("prepare accelerator")
126
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
127
+ is_main_process = accelerator.is_main_process
128
 
129
  # mixed precisionに対応した型を用意しておき適宜castする
130
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
131
 
132
  # モデルを読み込む
133
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
134
+
135
+ # work on low-ram device
136
+ if args.lowram:
137
+ text_encoder.to("cuda")
138
+ unet.to("cuda")
139
+
140
  # モデルに xformers とか memory efficient attention を組み込む
141
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
142
 
 
146
  vae.requires_grad_(False)
147
  vae.eval()
148
  with torch.no_grad():
149
+ train_dataset_group.cache_latents(vae)
150
  vae.to("cpu")
151
  if torch.cuda.is_available():
152
  torch.cuda.empty_cache()
153
  gc.collect()
154
 
155
  # prepare network
156
+ import sys
157
+ sys.path.append(os.path.dirname(__file__))
158
  print("import network module:", args.network_module)
159
  network_module = importlib.import_module(args.network_module)
160
 
 
184
 
185
  # 学習に必要なクラスを準備する
186
  print("prepare optimizer, data loader etc.")
187
+ split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
188
+
189
+ used_names = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  if args.split_lora_networks:
191
+ lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
192
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
193
+ append_module.replace_prepare_optimizer_params(network, network_module)
194
+ trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
195
  else:
196
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
197
+ if split_flag:
198
+ _t_lr = 0.
199
+ _u_lr = 0.
200
+ if args.text_encoder_lr:
201
+ _t_lr = args.text_encoder_lr
202
+ if args.unet_lr:
203
+ _u_lr = args.unet_lr
204
+ adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
205
+
206
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
207
+ if args.use_lookahead:
208
+ try:
209
+ import torch_optimizer
210
+ lookahed_arg = {"k": 5, "alpha": 0.5}
211
+ if args.lookahead_arg is not None:
212
+ for _arg in args.lookahead_arg:
213
+ k, v = _arg.split("=")
214
+ if k == "k":
215
+ lookahed_arg[k] = int(v)
216
+ else:
217
+ lookahed_arg[k] = float(v)
218
+ optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
219
+ except:
220
+ print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
221
  # dataloaderを準備する
222
  # DataLoaderのプロセス数:0はメインプロセスになる
223
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
224
  train_dataloader = torch.utils.data.DataLoader(
225
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
226
 
227
  # 学習ステップ数を計算する
228
  if args.max_train_epochs is not None:
229
+ args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
230
+ if is_main_process:
231
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
232
 
233
  # lr schedulerを用意する
234
+ if args.lr_scheduler.startswith("adafactor") and split_flag:
235
+ lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
 
 
 
 
236
  else:
237
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
238
+ num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
239
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
240
+
 
241
  #追加機能の設定をコメントに追記して残す
242
+ if args.use_lookahead:
243
+ args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
244
+ if args.split_lora_networks:
245
+ args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
246
  if args.min_resolution:
247
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
248
 
 
312
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
313
 
314
  # 学習する
315
+ # TODO: find a way to handle total batch size when there are multiple datasets
316
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
317
+
318
+ if is_main_process:
319
+ print("running training / 学習開始")
320
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
321
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
322
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
323
+ print(f" num epochs / epoch数: {num_train_epochs}")
324
+ print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
325
+ # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
326
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
327
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
328
+
329
+ # TODO refactor metadata creation and move to util
330
  metadata = {
331
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
332
  "ss_training_started_at": training_started_at, # unix timestamp
 
334
  "ss_learning_rate": args.learning_rate,
335
  "ss_text_encoder_lr": args.text_encoder_lr,
336
  "ss_unet_lr": args.unet_lr,
337
+ "ss_num_train_images": train_dataset_group.num_train_images,
338
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
339
  "ss_num_batches_per_epoch": len(train_dataloader),
340
  "ss_num_epochs": num_train_epochs,
 
 
341
  "ss_gradient_checkpointing": args.gradient_checkpointing,
342
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
343
  "ss_max_train_steps": args.max_train_steps,
 
349
  "ss_mixed_precision": args.mixed_precision,
350
  "ss_full_fp16": bool(args.full_fp16),
351
  "ss_v2": bool(args.v2),
 
352
  "ss_clip_skip": args.clip_skip,
353
  "ss_max_token_length": args.max_token_length,
 
 
 
 
354
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
355
  "ss_seed": args.seed,
356
+ "ss_lowram": args.lowram,
357
  "ss_noise_offset": args.noise_offset,
 
 
 
 
358
  "ss_training_comment": args.training_comment, # will not be updated after training
359
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
360
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
361
+ "ss_max_grad_norm": args.max_grad_norm,
362
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
363
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
364
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
365
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
366
+ "ss_prior_loss_weight": args.prior_loss_weight,
367
  }
368
 
369
+ if use_user_config:
370
+ # save metadata of multiple datasets
371
+ # NOTE: pack "ss_datasets" value as json one time
372
+ # or should also pack nested collections as json?
373
+ datasets_metadata = []
374
+ tag_frequency = {} # merge tag frequency for metadata editor
375
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
376
+
377
+ for dataset in train_dataset_group.datasets:
378
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
379
+ dataset_metadata = {
380
+ "is_dreambooth": is_dreambooth_dataset,
381
+ "batch_size_per_device": dataset.batch_size,
382
+ "num_train_images": dataset.num_train_images, # includes repeating
383
+ "num_reg_images": dataset.num_reg_images,
384
+ "resolution": (dataset.width, dataset.height),
385
+ "enable_bucket": bool(dataset.enable_bucket),
386
+ "min_bucket_reso": dataset.min_bucket_reso,
387
+ "max_bucket_reso": dataset.max_bucket_reso,
388
+ "tag_frequency": dataset.tag_frequency,
389
+ "bucket_info": dataset.bucket_info,
390
+ }
391
+
392
+ subsets_metadata = []
393
+ for subset in dataset.subsets:
394
+ subset_metadata = {
395
+ "img_count": subset.img_count,
396
+ "num_repeats": subset.num_repeats,
397
+ "color_aug": bool(subset.color_aug),
398
+ "flip_aug": bool(subset.flip_aug),
399
+ "random_crop": bool(subset.random_crop),
400
+ "shuffle_caption": bool(subset.shuffle_caption),
401
+ "keep_tokens": subset.keep_tokens,
402
+ }
403
+
404
+ image_dir_or_metadata_file = None
405
+ if subset.image_dir:
406
+ image_dir = os.path.basename(subset.image_dir)
407
+ subset_metadata["image_dir"] = image_dir
408
+ image_dir_or_metadata_file = image_dir
409
+
410
+ if is_dreambooth_dataset:
411
+ subset_metadata["class_tokens"] = subset.class_tokens
412
+ subset_metadata["is_reg"] = subset.is_reg
413
+ if subset.is_reg:
414
+ image_dir_or_metadata_file = None # not merging reg dataset
415
+ else:
416
+ metadata_file = os.path.basename(subset.metadata_file)
417
+ subset_metadata["metadata_file"] = metadata_file
418
+ image_dir_or_metadata_file = metadata_file # may overwrite
419
+
420
+ subsets_metadata.append(subset_metadata)
421
+
422
+ # merge dataset dir: not reg subset only
423
+ # TODO update additional-network extension to show detailed dataset config from metadata
424
+ if image_dir_or_metadata_file is not None:
425
+ # datasets may have a certain dir multiple times
426
+ v = image_dir_or_metadata_file
427
+ i = 2
428
+ while v in dataset_dirs_info:
429
+ v = image_dir_or_metadata_file + f" ({i})"
430
+ i += 1
431
+ image_dir_or_metadata_file = v
432
+
433
+ dataset_dirs_info[image_dir_or_metadata_file] = {
434
+ "n_repeats": subset.num_repeats,
435
+ "img_count": subset.img_count
436
+ }
437
+
438
+ dataset_metadata["subsets"] = subsets_metadata
439
+ datasets_metadata.append(dataset_metadata)
440
+
441
+ # merge tag frequency:
442
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
443
+ # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
444
+ # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
445
+ # なので、ここで複数datasetの回数を合算してもあまり意味はない
446
+ if ds_dir_name in tag_frequency:
447
+ continue
448
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
449
+
450
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
451
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
452
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
453
+ else:
454
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
455
+ assert len(
456
+ train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
457
+
458
+ dataset = train_dataset_group.datasets[0]
459
+
460
+ dataset_dirs_info = {}
461
+ reg_dataset_dirs_info = {}
462
+ if use_dreambooth_method:
463
+ for subset in dataset.subsets:
464
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
465
+ info[os.path.basename(subset.image_dir)] = {
466
+ "n_repeats": subset.num_repeats,
467
+ "img_count": subset.img_count
468
+ }
469
+ else:
470
+ for subset in dataset.subsets:
471
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
472
+ "n_repeats": subset.num_repeats,
473
+ "img_count": subset.img_count
474
+ }
475
+
476
+ metadata.update({
477
+ "ss_batch_size_per_device": args.train_batch_size,
478
+ "ss_total_batch_size": total_batch_size,
479
+ "ss_resolution": args.resolution,
480
+ "ss_color_aug": bool(args.color_aug),
481
+ "ss_flip_aug": bool(args.flip_aug),
482
+ "ss_random_crop": bool(args.random_crop),
483
+ "ss_shuffle_caption": bool(args.shuffle_caption),
484
+ "ss_enable_bucket": bool(dataset.enable_bucket),
485
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
486
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
487
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
488
+ "ss_keep_tokens": args.keep_tokens,
489
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
490
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
491
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
492
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
493
+ })
494
+
495
+ # add extra args
496
+ if args.network_args:
497
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
498
  # for key, value in net_kwargs.items():
499
  # metadata["ss_arg_" + key] = value
500
 
501
+ # model name and hash
502
  if args.pretrained_model_name_or_path is not None:
503
  sd_model_name = args.pretrained_model_name_or_path
504
  if os.path.exists(sd_model_name):
 
517
 
518
  metadata = {k: str(v) for k, v in metadata.items()}
519
 
520
+ # make minimum metadata for filtering
521
+ minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
522
+ minimum_metadata = {}
523
+ for key in minimum_keys:
524
+ if key in metadata:
525
+ minimum_metadata[key] = metadata[key]
526
+
527
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
528
  global_step = 0
529
 
 
536
  loss_list = []
537
  loss_total = 0.0
538
  for epoch in range(num_train_epochs):
539
+ if is_main_process:
540
+ print(f"epoch {epoch+1}/{num_train_epochs}")
541
+ train_dataset_group.set_current_epoch(epoch + 1)
542
 
543
  metadata["ss_epoch"] = str(epoch+1)
544
 
 
575
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
576
 
577
  # Predict the noise residual
578
+ with accelerator.autocast():
579
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
580
 
581
  if args.v_parameterization:
 
593
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
594
 
595
  accelerator.backward(loss)
596
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
597
  params_to_clip = network.get_trainable_params()
598
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
599
 
600
  optimizer.step()
601
+ if accelerator.sync_gradients:
602
+ lr_scheduler.step()
603
  optimizer.zero_grad(set_to_none=True)
604
 
605
  # Checks if the accelerator has performed an optimization step behind the scenes
 
607
  progress_bar.update(1)
608
  global_step += 1
609
 
610
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
611
+
612
  current_loss = loss.detach().item()
613
  if epoch == 0:
614
  loss_list.append(current_loss)
 
621
  progress_bar.set_postfix(**logs)
622
 
623
  if args.logging_dir is not None:
624
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
625
  accelerator.log(logs, step=global_step)
626
 
627
  if global_step >= args.max_train_steps:
 
639
  def save_func():
640
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
641
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
642
+ metadata["ss_training_finished_at"] = str(time.time())
643
  print(f"saving checkpoint: {ckpt_file}")
644
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
645
 
646
  def remove_old_func(old_epoch_no):
647
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
650
  print(f"removing old checkpoint: {old_ckpt_file}")
651
  os.remove(old_ckpt_file)
652
 
653
+ if is_main_process:
654
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
655
+ if saving and args.save_state:
656
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
657
+
658
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
659
 
660
  # end of epoch
661
 
662
  metadata["ss_epoch"] = str(num_train_epochs)
663
+ metadata["ss_training_finished_at"] = str(time.time())
664
 
 
665
  if is_main_process:
666
  network = unwrap_model(network)
667
 
 
680
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
681
 
682
  print(f"save trained model to {ckpt_file}")
683
+ network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
684
  print("model saved.")
685
 
686
 
 
690
  train_util.add_sd_models_arguments(parser)
691
  train_util.add_dataset_arguments(parser, True, True, True)
692
  train_util.add_training_arguments(parser, True)
693
+ train_util.add_optimizer_arguments(parser)
694
+ config_util.add_config_arguments(parser)
695
 
696
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
697
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
699
 
700
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
701
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
702
 
703
  parser.add_argument("--network_weights", type=str, default=None,
704
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
718
  #Optimizer変更関連のオプション追加
719
  append_module.add_append_arguments(parser)
720
  args = append_module.get_config(parser)
721
+ if not args.not_output_config:
722
+ #argsを保存する
723
+ import yaml
724
+ import datetime
725
+ _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
726
+ if args.output_name==None:
727
+ config_name = f"train_network_config_{_t}.yaml"
728
+ else:
729
+ config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
730
+ print(f"{config_name} に設定を書き出し中...")
731
+ with open(config_name, mode="w") as f:
732
+ yaml.dump(args.__dict__, f, indent=4)
733
 
734
  if args.resolution==args.min_resolution:
735
  args.min_resolution=None
736
 
737
  train(args)
738
+ print("done!")
739
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
  '''
742
  optimizer設定メモ
743
+ torch_optimizer.AdaBelief
744
+ adastand.Adastand
745
  (optimizer_argから設定できるように変更するためのメモ)
746
 
747
  AdamWのweight_decay初期値は1e-2
 
771
  transformerベースのT5学習において最強とかいう噂のoptimizer
772
  huggingfaceのサンプルパラ
773
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
774
+ epsの二つ目の値1e-3が学習率に影響大きい
775
 
776
  AggMo
777
 
train_textual_inversion.py CHANGED
@@ -11,7 +11,11 @@ import diffusers
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
15
 
16
  imagenet_templates_small = [
17
  "a photo of a {}",
@@ -79,7 +83,6 @@ def train(args):
79
  train_util.prepare_dataset_args(args, True)
80
 
81
  cache_latents = args.cache_latents
82
- use_dreambooth_method = args.in_json is None
83
 
84
  if args.seed is not None:
85
  set_seed(args.seed)
@@ -139,21 +142,35 @@ def train(args):
139
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
140
 
141
  # データセットを準備する
142
- if use_dreambooth_method:
143
- print("Use DreamBooth method.")
144
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
145
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
146
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
147
- args.bucket_reso_steps, args.bucket_no_upscale,
148
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
149
  else:
150
- print("Train with captions.")
151
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
152
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
153
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
154
- args.bucket_reso_steps, args.bucket_no_upscale,
155
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
156
- args.dataset_repeats, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
159
  if use_template:
@@ -163,20 +180,30 @@ def train(args):
163
  captions = []
164
  for tmpl in templates:
165
  captions.append(tmpl.format(replace_to))
166
- train_dataset.add_replacement("", captions)
167
- elif args.num_vectors_per_token > 1:
168
- replace_to = " ".join(token_strings)
169
- train_dataset.add_replacement(args.token_string, replace_to)
170
 
171
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
 
 
172
 
173
  if args.debug_dataset:
174
- train_util.debug_dataset(train_dataset, show_input_ids=True)
175
  return
176
- if len(train_dataset) == 0:
177
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
178
  return
179
 
 
 
 
180
  # モデルに xformers とか memory efficient attention を組み込む
181
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
182
 
@@ -186,7 +213,7 @@ def train(args):
186
  vae.requires_grad_(False)
187
  vae.eval()
188
  with torch.no_grad():
189
- train_dataset.cache_latents(vae)
190
  vae.to("cpu")
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
@@ -198,35 +225,14 @@ def train(args):
198
 
199
  # 学習に必要なクラスを準備する
200
  print("prepare optimizer, data loader etc.")
201
-
202
- # 8-bit Adamを使う
203
- if args.use_8bit_adam:
204
- try:
205
- import bitsandbytes as bnb
206
- except ImportError:
207
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
208
- print("use 8-bit Adam optimizer")
209
- optimizer_class = bnb.optim.AdamW8bit
210
- elif args.use_lion_optimizer:
211
- try:
212
- import lion_pytorch
213
- except ImportError:
214
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
215
- print("use Lion optimizer")
216
- optimizer_class = lion_pytorch.Lion
217
- else:
218
- optimizer_class = torch.optim.AdamW
219
-
220
  trainable_params = text_encoder.get_input_embeddings().parameters()
221
-
222
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
223
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
224
 
225
  # dataloaderを準備する
226
  # DataLoaderのプロセス数:0はメインプロセスになる
227
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
228
  train_dataloader = torch.utils.data.DataLoader(
229
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
230
 
231
  # 学習ステップ数を計算する
232
  if args.max_train_epochs is not None:
@@ -234,8 +240,9 @@ def train(args):
234
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
235
 
236
  # lr schedulerを用意する
237
- lr_scheduler = diffusers.optimization.get_scheduler(
238
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
239
 
240
  # acceleratorがなんかよろしくやってくれるらしい
241
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -283,8 +290,8 @@ def train(args):
283
  # 学習する
284
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
285
  print("running training / 学習開始")
286
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
287
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
288
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
289
  print(f" num epochs / epoch数: {num_train_epochs}")
290
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -303,12 +310,11 @@ def train(args):
303
 
304
  for epoch in range(num_train_epochs):
305
  print(f"epoch {epoch+1}/{num_train_epochs}")
306
- train_dataset.set_current_epoch(epoch + 1)
307
 
308
  text_encoder.train()
309
 
310
  loss_total = 0
311
- bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
312
  for step, batch in enumerate(train_dataloader):
313
  with accelerator.accumulate(text_encoder):
314
  with torch.no_grad():
@@ -357,9 +363,9 @@ def train(args):
357
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
358
 
359
  accelerator.backward(loss)
360
- if accelerator.sync_gradients:
361
  params_to_clip = text_encoder.get_input_embeddings().parameters()
362
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
363
 
364
  optimizer.step()
365
  lr_scheduler.step()
@@ -374,9 +380,14 @@ def train(args):
374
  progress_bar.update(1)
375
  global_step += 1
376
 
 
 
 
377
  current_loss = loss.detach().item()
378
  if args.logging_dir is not None:
379
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
380
  accelerator.log(logs, step=global_step)
381
 
382
  loss_total += current_loss
@@ -394,8 +405,6 @@ def train(args):
394
  accelerator.wait_for_everyone()
395
 
396
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
397
- # d = updated_embs - bef_epo_embs
398
- # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
399
 
400
  if args.save_every_n_epochs is not None:
401
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -417,6 +426,9 @@ def train(args):
417
  if saving and args.save_state:
418
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
419
 
 
 
 
420
  # end of epoch
421
 
422
  is_main_process = accelerator.is_main_process
@@ -491,6 +503,8 @@ if __name__ == '__main__':
491
  train_util.add_sd_models_arguments(parser)
492
  train_util.add_dataset_arguments(parser, True, True, False)
493
  train_util.add_training_arguments(parser, True)
 
 
494
 
495
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
496
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
 
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
+ import library.config_util as config_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
 
20
  imagenet_templates_small = [
21
  "a photo of a {}",
 
83
  train_util.prepare_dataset_args(args, True)
84
 
85
  cache_latents = args.cache_latents
 
86
 
87
  if args.seed is not None:
88
  set_seed(args.seed)
 
142
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
143
 
144
  # データセットを準備する
145
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
146
+ if args.dataset_config is not None:
147
+ print(f"Load dataset config from {args.dataset_config}")
148
+ user_config = config_util.load_user_config(args.dataset_config)
149
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
150
+ if any(getattr(args, attr) is not None for attr in ignored):
151
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
152
  else:
153
+ use_dreambooth_method = args.in_json is None
154
+ if use_dreambooth_method:
155
+ print("Use DreamBooth method.")
156
+ user_config = {
157
+ "datasets": [{
158
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
159
+ }]
160
+ }
161
+ else:
162
+ print("Train with captions.")
163
+ user_config = {
164
+ "datasets": [{
165
+ "subsets": [{
166
+ "image_dir": args.train_data_dir,
167
+ "metadata_file": args.in_json,
168
+ }]
169
+ }]
170
+ }
171
+
172
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
173
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
174
 
175
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
176
  if use_template:
 
180
  captions = []
181
  for tmpl in templates:
182
  captions.append(tmpl.format(replace_to))
183
+ train_dataset_group.add_replacement("", captions)
 
 
 
184
 
185
+ if args.num_vectors_per_token > 1:
186
+ prompt_replacement = (args.token_string, replace_to)
187
+ else:
188
+ prompt_replacement = None
189
+ else:
190
+ if args.num_vectors_per_token > 1:
191
+ replace_to = " ".join(token_strings)
192
+ train_dataset_group.add_replacement(args.token_string, replace_to)
193
+ prompt_replacement = (args.token_string, replace_to)
194
+ else:
195
+ prompt_replacement = None
196
 
197
  if args.debug_dataset:
198
+ train_util.debug_dataset(train_dataset_group, show_input_ids=True)
199
  return
200
+ if len(train_dataset_group) == 0:
201
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
202
  return
203
 
204
+ if cache_latents:
205
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
206
+
207
  # モデルに xformers とか memory efficient attention を組み込む
208
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
209
 
 
213
  vae.requires_grad_(False)
214
  vae.eval()
215
  with torch.no_grad():
216
+ train_dataset_group.cache_latents(vae)
217
  vae.to("cpu")
218
  if torch.cuda.is_available():
219
  torch.cuda.empty_cache()
 
225
 
226
  # 学習に必要なクラスを準備する
227
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  trainable_params = text_encoder.get_input_embeddings().parameters()
229
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
230
 
231
  # dataloaderを準備する
232
  # DataLoaderのプロセス数:0はメインプロセスになる
233
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
234
  train_dataloader = torch.utils.data.DataLoader(
235
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
236
 
237
  # 学習ステップ数を計算する
238
  if args.max_train_epochs is not None:
 
240
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
241
 
242
  # lr schedulerを用意する
243
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
244
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
245
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
246
 
247
  # acceleratorがなんかよろしくやってくれるらしい
248
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 
290
  # 学習する
291
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
292
  print("running training / 学習開始")
293
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
294
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
295
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
296
  print(f" num epochs / epoch数: {num_train_epochs}")
297
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
310
 
311
  for epoch in range(num_train_epochs):
312
  print(f"epoch {epoch+1}/{num_train_epochs}")
313
+ train_dataset_group.set_current_epoch(epoch + 1)
314
 
315
  text_encoder.train()
316
 
317
  loss_total = 0
 
318
  for step, batch in enumerate(train_dataloader):
319
  with accelerator.accumulate(text_encoder):
320
  with torch.no_grad():
 
363
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
364
 
365
  accelerator.backward(loss)
366
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
367
  params_to_clip = text_encoder.get_input_embeddings().parameters()
368
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
369
 
370
  optimizer.step()
371
  lr_scheduler.step()
 
380
  progress_bar.update(1)
381
  global_step += 1
382
 
383
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
384
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
385
+
386
  current_loss = loss.detach().item()
387
  if args.logging_dir is not None:
388
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
389
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
390
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
391
  accelerator.log(logs, step=global_step)
392
 
393
  loss_total += current_loss
 
405
  accelerator.wait_for_everyone()
406
 
407
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
 
 
408
 
409
  if args.save_every_n_epochs is not None:
410
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
 
426
  if saving and args.save_state:
427
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
428
 
429
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
430
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
431
+
432
  # end of epoch
433
 
434
  is_main_process = accelerator.is_main_process
 
503
  train_util.add_sd_models_arguments(parser)
504
  train_util.add_dataset_arguments(parser, True, True, False)
505
  train_util.add_training_arguments(parser, True)
506
+ train_util.add_optimizer_arguments(parser)
507
+ config_util.add_config_arguments(parser)
508
 
509
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
510
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
train_ti_README-ja.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。
2
+
3
+ [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
4
+
5
+ 実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
6
+
7
+ 学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。
8
+
9
+ # 学習の手順
10
+
11
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
12
+
13
+ ## データの準備
14
+
15
+ [学習データの準備について](./train_README-ja.md) を参照してください。
16
+
17
+ ## 学習の実行
18
+
19
+ ``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。
20
+
21
+ ```
22
+ accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
23
+ --dataset_config=<データ準備で作成した.tomlファイル>
24
+ --output_dir=<学習したモデルの出力先フォルダ>
25
+ --output_name=<学習したモデル出力時のファイル名>
26
+ --save_model_as=safetensors
27
+ --prior_loss_weight=1.0
28
+ --max_train_steps=1600
29
+ --learning_rate=1e-6
30
+ --optimizer_type="AdamW8bit"
31
+ --xformers
32
+ --mixed_precision="fp16"
33
+ --cache_latents
34
+ --gradient_checkpointing
35
+ --token_string=mychar4 --init_word=cute --num_vectors_per_token=4
36
+ ```
37
+
38
+ ``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。
39
+
40
+ プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。
41
+
42
+ ```
43
+ input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
44
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
45
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
46
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
47
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
48
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
50
+ 49407, 49407, 49407, 49407, 49407, 49407, 49407]])
51
+ ```
52
+
53
+ tokenizerがすでに持っている単語(一般的な単語)は使用できません。
54
+
55
+ ``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。
56
+
57
+ ``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。
58
+
59
+ 以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。
60
+
61
+ `num_cpu_threads_per_process` には通常は1を指定するとよいようです。
62
+
63
+ `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
64
+
65
+ `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
66
+
67
+ `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
68
+
69
+ 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
70
+
71
+ 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
72
+
73
+ オプティマイザ(���デルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
74
+
75
+ `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
76
+
77
+ ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。
78
+
79
+ ### よく使われるオプションについて
80
+
81
+ 以下の場合にはオプションに関するドキュメントを参照してください。
82
+
83
+ - Stable Diffusion 2.xまたはそこからの派生モデルを学習する
84
+ - clip skipを2以上を前提としたモデルを学習する
85
+ - 75トークンを超えたキャプションで学習する
86
+
87
+ ### Textual Inversionでのバッチサイズについて
88
+
89
+ モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。
90
+
91
+ # Textual Inversionのその他の主なオプション
92
+
93
+ すべてのオプションについては別文書を参照してください。
94
+
95
+ * `--weights`
96
+ * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
97
+ * `--use_object_template`
98
+ * キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
99
+ * `--use_style_template`
100
+ * キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。
101
+
102
+ ## 当リポジトリ内の画像生成スクリプトで生成する
103
+
104
+ gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。
105
+