abc committed on
Commit 94f2ce5 · 1 Parent(s): 07048a3

Upload 54 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .github/workflows/typos.yml +21 -0
  2. .gitignore +7 -0
  3. LICENSE.md +201 -0
  4. _typos.toml +15 -0
  5. adastand.py +291 -0
  6. append_module.py +504 -0
  7. build/lib/library/__init__.py +0 -0
  8. build/lib/library/model_util.py +1180 -0
  9. build/lib/library/train_util.py +1796 -0
  10. fine_tune.py +360 -0
  11. gen_img_diffusers.py +0 -0
  12. library.egg-info/PKG-INFO +4 -0
  13. library.egg-info/SOURCES.txt +10 -0
  14. library.egg-info/dependency_links.txt +1 -0
  15. library.egg-info/top_level.txt +1 -0
  16. library/__init__.py +0 -0
  17. library/__pycache__/__init__.cpython-310.pyc +0 -0
  18. library/__pycache__/model_util.cpython-310.pyc +0 -0
  19. library/__pycache__/train_util.cpython-310.pyc +0 -0
  20. library/model_util.py +1180 -0
  21. library/train_util.py +1796 -0
  22. locon/__init__.py +0 -0
  23. locon/kohya_model_utils.py +1184 -0
  24. locon/kohya_utils.py +48 -0
  25. locon/locon.py +53 -0
  26. locon/locon_kohya.py +243 -0
  27. locon/utils.py +148 -0
  28. lora_train_popup.py +862 -0
  29. lycoris/__init__.py +8 -0
  30. lycoris/kohya.py +276 -0
  31. lycoris/kohya_model_utils.py +1184 -0
  32. lycoris/kohya_utils.py +48 -0
  33. lycoris/locon.py +76 -0
  34. lycoris/loha.py +116 -0
  35. lycoris/utils.py +271 -0
  36. networks/__pycache__/lora.cpython-310.pyc +0 -0
  37. networks/check_lora_weights.py +32 -0
  38. networks/extract_lora_from_models.py +164 -0
  39. networks/lora.py +237 -0
  40. networks/lora_interrogator.py +122 -0
  41. networks/merge_lora.py +212 -0
  42. networks/merge_lora_old.py +179 -0
  43. networks/resize_lora.py +198 -0
  44. networks/svd_merge_lora.py +164 -0
  45. requirements.txt +25 -0
  46. requirements_startup.txt +23 -0
  47. setup.py +3 -0
  48. tools/convert_diffusers20_original_sd.py +89 -0
  49. tools/detect_face_rotate.py +239 -0
  50. tools/resize_images_to_resolution.py +122 -0
.github/workflows/typos.yml ADDED
@@ -0,0 +1,21 @@
+ ---
+ # yamllint disable rule:line-length
+ name: Typos
+
+ on:  # yamllint disable-line rule:truthy
+   push:
+   pull_request:
+     types:
+       - opened
+       - synchronize
+       - reopened
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v3
+
+       - name: typos-action
+         uses: crate-ci/[email protected]
.gitignore ADDED
@@ -0,0 +1,7 @@
+ logs
+ __pycache__
+ wd14_tagger_model
+ venv
+ *.egg-info
+ build
+ .vscode
LICENSE.md ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2022] [kohya-ss]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
_typos.toml ADDED
@@ -0,0 +1,15 @@
+ # Files for typos
+ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+ [default.extend-identifiers]
+
+ [default.extend-words]
+ NIN="NIN"
+ parms="parms"
+ nin="nin"
+ extention="extention" # Intentionally left
+ nd="nd"
+
+
+ [files]
+ extend-exclude = ["_typos.toml"]
adastand.py ADDED
@@ -0,0 +1,291 @@
+ import torch
+ import math
+ def __version__():
+     return 0.5
+ #######################################################################################
+ # Adastand optimizer proposed by NTT; if NTT is to be believed, slightly better performance than Adam (2019)
+ # Reference code: https://github.com/bunag-public/adastand_pack/
+ # AdaBelief, which uses a similar formulation, adopted the same weight_decay computation as AdamW, so the AdamW-style weight_decay is made available here as well
+ class Adastand(torch.optim.Optimizer):
+     """Implements Adastand algorithm.
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float, optional): learning rate (default: 1e-3)
+         betas (Tuple[float, float], optional): coefficients used for computing
+             running averages of gradient and its square (default: (0.9, 0.999))
+         eps (float, optional): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+         weight_decouple (bool, optional): if True, apply decoupled weight decay as in AdamW
+     """
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                  weight_decay=0, weight_decouple=False, fixed_decay=False, amsgrad=False):
+         if not 0.0 <= lr:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         defaults = dict(lr=lr, betas=betas, eps=eps,
+                         weight_decay=weight_decay, weight_decouple=weight_decouple, fixed_decay=fixed_decay, amsgrad=amsgrad)
+         super(Adastand, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(Adastand, self).__setstate__(state)
+
+     def step(self, closure=None):
+         """Performs a single optimization step.
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError('Adastand does not support sparse gradients, please consider SparseAdam instead')
+                 weight_decouple = group['weight_decouple']
+                 fixed_decay = group['fixed_decay']
+                 amsgrad = group['amsgrad']
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     if amsgrad:
+                         # Maintains max of all exp. moving avg. of
+                         # sq. grad. values
+                         state['exp_avg_sqs'] = torch.zeros_like(
+                             p.data, memory_format=torch.preserve_format
+                         )
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                 beta1, beta2 = group['betas']
+
+                 state['step'] += 1
+
+                 if weight_decouple:
+                     if not fixed_decay:
+                         p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                     else:
+                         p.data.mul_(1.0 - group['weight_decay'])
+                 else:
+                     if group['weight_decay'] != 0:
+                         grad.add_(p.data, alpha=group['weight_decay'])
+
+                 # Decay the first and second moment running average coefficient
+                 grad_residual = grad - exp_avg
+                 exp_avg_sq.mul_(beta2).addcmul_(grad_residual, grad_residual, value=beta2 * (1 - beta2))
+                 exp_avg.mul_(2 * beta1 - 1).add_(grad, alpha=1 - beta1)
+
+                 bias_correction1 = 1 - beta1 ** state['step']
+                 bias_correction2 = 1 - beta2 ** state['step']
+                 if amsgrad:
+                     exp_avg_sqs = state['exp_avg_sqs']
+                     torch.max(exp_avg_sqs, exp_avg_sq, out=exp_avg_sqs)
+                     denom = exp_avg_sqs.sqrt().add_(group['eps'] / math.sqrt(bias_correction2))
+                 else:
+                     denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                 p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                 #p.data.addcdiv_(-step_size, exp_avg, denom)
+
+         return loss
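For context, a minimal usage sketch of the Adastand class above; the toy model and data here are illustrative and not part of the committed file:

import torch
from adastand import Adastand  # assumes adastand.py is on the import path

model = torch.nn.Linear(10, 1)
# weight_decouple=True applies AdamW-style decoupled weight decay, per the docstring above
optimizer = Adastand(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                     weight_decay=1e-2, weight_decouple=True)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()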
+ #######################################################################################
+ class Adastand_b(torch.optim.Optimizer):
+     """Implements Adastand algorithm.
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float, optional): learning rate (default: 1e-3)
+         betas (Tuple[float, float], optional): coefficients used for computing
+             running averages of gradient and its square (default: (0.9, 0.999))
+         eps (float, optional): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+         weight_decouple (bool, optional): if True, apply decoupled weight decay as in AdamW
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         betas = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0,
+         amsgrad: bool = False,
+         weight_decouple: bool = False,
+         fixed_decay: bool = False,
+         rectify: bool = False,
+     ) -> None:
+         if lr <= 0.0:
+             raise ValueError('Invalid learning rate: {}'.format(lr))
+         if eps < 0.0:
+             raise ValueError('Invalid epsilon value: {}'.format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError(
+                 'Invalid beta parameter at index 0: {}'.format(betas[0])
+             )
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError(
+                 'Invalid beta parameter at index 1: {}'.format(betas[1])
+             )
+         if weight_decay < 0:
+             raise ValueError(
+                 'Invalid weight_decay value: {}'.format(weight_decay)
+             )
+         defaults = dict(
+             lr=lr,
+             betas=betas,
+             eps=eps,
+             weight_decay=weight_decay,
+             amsgrad=amsgrad,
+         )
+         super(Adastand_b, self).__init__(params, defaults)
+
+         self._weight_decouple = weight_decouple
+         self._rectify = rectify
+         self._fixed_decay = fixed_decay
+
+     def __setstate__(self, state):
+         super(Adastand_b, self).__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('amsgrad', False)
+
+     def step(self, closure=None):
+         r"""Performs a single optimization step.
+
+         Arguments:
+             closure: A closure that reevaluates the model and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError(
+                         'Adastand_b does not support sparse gradients, '
+                         'please consider SparseAdam instead'
+                     )
+                 amsgrad = group['amsgrad']
+
+                 state = self.state[p]
+
+                 beta1, beta2 = group['betas']
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(
+                         p.data, memory_format=torch.preserve_format
+                     )
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_var'] = torch.zeros_like(
+                         p.data, memory_format=torch.preserve_format
+                     )
+                     if amsgrad:
+                         # Maintains max of all exp. moving avg. of
+                         # sq. grad. values
+                         state['max_exp_avg_var'] = torch.zeros_like(
+                             p.data, memory_format=torch.preserve_format
+                         )
+
+                 # get current state variable
+                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+
+                 state['step'] += 1
+                 bias_correction1 = 1 - beta1 ** state['step']
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 # perform weight decay, check if decoupled weight decay
+                 if self._weight_decouple:
+                     if not self._fixed_decay:
+                         p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                     else:
+                         p.data.mul_(1.0 - group['weight_decay'])
+                 else:
+                     if group['weight_decay'] != 0:
+                         grad.add_(p.data, alpha=group['weight_decay'])
+
+                 # Update first and second moment running average
+                 exp_avg.mul_(2*beta1-1).add_(grad, alpha=1 - beta1)
+                 grad_residual = grad - exp_avg
+                 exp_avg_var.mul_(beta2).addcmul_(
+                     grad_residual, grad_residual, value=beta2*(1 - beta2)
+                 )
+
+                 if amsgrad:
+                     max_exp_avg_var = state['max_exp_avg_var']
+                     # Maintains the maximum of all 2nd moment running
+                     # avg. till now
+                     torch.max(
+                         max_exp_avg_var, exp_avg_var, out=max_exp_avg_var
+                     )
+
+                     # Use the max. for normalizing running avg. of gradient
+                     denom = (
+                         max_exp_avg_var.add_(group['eps']).sqrt()
+                         / math.sqrt(bias_correction2)
+                     ).add_(group['eps'])
+                 else:
+                     denom = (
+                         exp_avg_var.add_(group['eps']).sqrt()
+                         / math.sqrt(bias_correction2)
+                     ).add_(group['eps'])
+
+                 if not self._rectify:
+                     # Default update
+                     step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                     p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+                 else:  # Rectified update
+                     # calculate rho_t
+                     state['rho_t'] = state['rho_inf'] - 2 * state[
+                         'step'
+                     ] * beta2 ** state['step'] / (1.0 - beta2 ** state['step'])
+
+                     if (
+                         state['rho_t'] > 4
+                     ):  # perform Adam style update if variance is small
+                         rho_inf, rho_t = state['rho_inf'], state['rho_t']
+                         rt = (
+                             (rho_t - 4.0)
+                             * (rho_t - 2.0)
+                             * rho_inf
+                             / (rho_inf - 4.0)
+                             / (rho_inf - 2.0)
+                             / rho_t
+                         )
+                         rt = math.sqrt(rt)
+
+                         step_size = rt * group['lr'] / bias_correction1
+
+                         p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+                     else:  # perform SGD style update
+                         p.data.add_(exp_avg, alpha=-group['lr'])
+
+         return loss
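The rectified branch of Adastand_b follows an RAdam-style variance rectification. A small standalone sketch of the factor it computes, mirroring the formulas in the step method above (illustrative only):

import math

def rectification_factor(step, beta2=0.999):
    # Same quantities as in Adastand_b.step: rho_inf, rho_t and the scale rt
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    rho_t = rho_inf - 2 * step * beta2 ** step / (1.0 - beta2 ** step)
    if rho_t <= 4:
        return None  # the optimizer falls back to the SGD-style update here
    rt = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf / (rho_inf - 4.0) / (rho_inf - 2.0) / rho_t
    return math.sqrt(rt)

print(rectification_factor(10), rectification_factor(1000))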
append_module.py ADDED
@@ -0,0 +1,504 @@
+ import argparse
+ import json
+ import shutil
+ import time
+ from typing import Dict, List, NamedTuple, Tuple
+ from accelerate import Accelerator
+ from torch.autograd.function import Function
+ import glob
+ import math
+ import os
+ import random
+ import hashlib
+ from io import BytesIO
+
+ from tqdm import tqdm
+ import torch
+ from torchvision import transforms
+ from transformers import CLIPTokenizer
+ import diffusers
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
+ import albumentations as albu
+ import numpy as np
+ from PIL import Image
+ import cv2
+ from einops import rearrange
+ from torch import einsum
+ import safetensors.torch
+
+ import library.model_util as model_util
+ import library.train_util as train_util
+
+ #============================================================================================================
+ # A provisional extension of AdafactorSchedule that lets initial_lr be applied per param group (per layer)
+ #============================================================================================================
+ from torch.optim.lr_scheduler import LambdaLR
+ class AdafactorSchedule_append(LambdaLR):
+     """
+     Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
+     for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
+
+     It returns `initial_lr` during startup and the actual `lr` during stepping.
+     """
+
+     def __init__(self, optimizer, initial_lr=0.0):
+         def lr_lambda(_):
+             return initial_lr
+
+         for group in optimizer.param_groups:
+             if not type(initial_lr)==list:
+                 group["initial_lr"] = initial_lr
+             else:
+                 group["initial_lr"] = initial_lr.pop(0)
+         super().__init__(optimizer, lr_lambda)
+         for group in optimizer.param_groups:
+             del group["initial_lr"]
+
+     def get_lr(self):
+         opt = self.optimizer
+         lrs = [
+             opt._get_lr(group, opt.state[group["params"][0]])
+             for group in opt.param_groups
+             if group["params"][0].grad is not None
+         ]
+         if len(lrs) == 0:
+             lrs = self.base_lrs  # if called before stepping
+         return lrs
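A minimal sketch of how this proxy scheduler might be wired up with the Adafactor optimizer from transformers; the model and the per-group initial_lr values are illustrative:

import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(10, 1)
optimizer = Adafactor(model.parameters(), scale_parameter=True,
                      relative_step=True, warmup_init=True, lr=None)
# One initial_lr per param group; a plain float is also accepted.
lr_scheduler = AdafactorSchedule_append(optimizer, initial_lr=[1e-3])
print(lr_scheduler.get_lr())  # returns base_lrs until the first optimizer step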
+
+ #============================================================================================================
+ # Based on code in model_util
+ #============================================================================================================
+ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024, divisible=64, step=1):
+     max_width, max_height = max_reso
+     max_area = (max_width // divisible) * (max_height // divisible)
+
+     min_width, min_height = min_reso
+     min_area = (min_width // divisible) * (min_height // divisible)
+
+     area_size_list = []
+     area_size_resos_list = []
+     _max_area = max_area
+
+     while True:
+         resos = set()
+         size = int(math.sqrt(_max_area)) * divisible
+         resos.add((size, size))
+
+         size = min_size
+         while size <= max_size:
+             width = size
+             height = min(max_size, (_max_area // (width // divisible)) * divisible)
+             resos.add((width, height))
+             resos.add((height, width))
+
+             # # make additional resos
+             # if width >= height and width - divisible >= min_size:
+             #     resos.add((width - divisible, height))
+             #     resos.add((height, width - divisible))
+             # if height >= width and height - divisible >= min_size:
+             #     resos.add((width, height - divisible))
+             #     resos.add((height - divisible, width))
+
+             size += divisible
+
+         resos = list(resos)
+         resos.sort()
+
+         #aspect_ratios = [w / h for w, h in resos]
+         area_size_list.append(_max_area)
+         area_size_resos_list.append(resos)
+         #area_size_ratio_list.append(aspect_ratios)
+
+         _max_area -= step
+         if _max_area < min_area:
+             break
+     return area_size_resos_list, area_size_list
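A quick call sketch for make_bucket_resolutions_fix: it returns one list of (width, height) buckets per area step, plus the corresponding area sizes measured in divisible-by-divisible blocks. The argument values below are illustrative:

resos_per_area, areas = make_bucket_resolutions_fix(
    max_reso=(768, 768), min_reso=(512, 512),
    min_size=256, max_size=1024, divisible=64, step=1)
print(len(areas), areas[0], resos_per_area[0][:3])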
+
+ #============================================================================================================
+ # Based on code in train_util
+ #============================================================================================================
+ class BucketManager_append(train_util.BucketManager):
+     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps, min_reso=None, area_step=None) -> None:
+         super().__init__(no_upscale, max_reso, min_size, max_size, reso_steps)
+         print("BucketManager_appendを作成しました")
+         if min_reso is None:
+             self.min_reso = None
+             self.min_area = None
+         else:
+             self.min_reso = min_reso
+             self.min_area = min_reso[0] * min_reso[1]
+         self.area_step = area_step
+         self.area_sizes_flag = False
+
+     def make_buckets(self):
+         if self.min_reso:
+             print(f"make_resolution append")
+             resos, area_sizes = make_bucket_resolutions_fix(self.max_reso, self.min_reso, self.min_size, self.max_size, self.reso_steps, self.area_step)
+             self.set_predefined_resos(resos, area_sizes)
+         else:
+             resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
+             self.set_predefined_resos(resos)
+
+     def set_predefined_resos(self, resos, area_sizes=None):
+         # store the resolution and aspect-ratio info used when picking from the predefined sizes
+         if area_sizes:
+             self.area_sizes_flag = True
+             self.predefined_area_sizes = np.array(area_sizes.copy())
+             self.predefined_resos_list = resos.copy()
+             self.predefined_resos_set_list = [set(reso) for reso in resos]
+             self.predefined_aspect_ratios_list = [np.array([w/h for w,h in reso]) for reso in resos]
+             self.predefined_resos = None
+             self.predefined_resos_set = None
+             self.predefined_aspect_ratios = None
+         else:
+             self.area_sizes_flag = False
+             self.predefined_area_sizes = None
+             self.predefined_resos = resos.copy()
+             self.predefined_resos_set = set(resos)
+             self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
+
+     def select_bucket(self, image_width, image_height):
+         # compute the image area
+         area_size = (image_width//64) * (image_height//64)
+         aspect_ratio = image_width / image_height
+         bucket_size_id = None
+         # decide which area range the image falls into so the extended bucket sizes can be used
+         if self.area_sizes_flag:
+             size_errors = self.predefined_area_sizes - area_size
+             bucket_size_id = np.abs(size_errors).argmin()
+             # search a small neighborhood of area buckets to settle on the size to use
+             search_size_range = 1
+             bucket_size_id_list = [bucket_size_id]
+             for i in range(search_size_range):
+                 if bucket_size_id - i < 0:
+                     bucket_size_id_list.append(bucket_size_id + i + 1)
+                 elif bucket_size_id + 1 + i >= len(self.predefined_resos_list):
+                     bucket_size_id_list.append(bucket_size_id - i - 1)
+                 else:
+                     bucket_size_id_list.append(bucket_size_id - i - 1)
+                     bucket_size_id_list.append(bucket_size_id + i + 1)
+             _min_error = 1000.
+             _min_id = bucket_size_id
+             for now_size_id in bucket_size_id_list:
+                 self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
+                 ar_errors = self.predefined_aspect_ratios - aspect_ratio
+                 ar_error = np.abs(ar_errors).min()
+                 if _min_error > ar_error:
+                     _min_error = ar_error
+                     _min_id = now_size_id
+                     if _min_error == 0.:
+                         break
+             bucket_size_id = _min_id
+             del _min_error, _min_id, ar_error  # clean up temporaries
+             self.predefined_resos = self.predefined_resos_list[bucket_size_id]
+             self.predefined_resos_set = self.predefined_resos_set_list[bucket_size_id]
+             self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[bucket_size_id]
+         # -- from here on, the original processing is unchanged
+         if not self.no_upscale:
+             # the same aspect ratio may already exist (when fine tuning with no_upscale=True preprocessing), so prefer an identical resolution
+             reso = (image_width, image_height)
+             if reso in self.predefined_resos_set:
+                 pass
+             else:
+                 ar_errors = self.predefined_aspect_ratios - aspect_ratio
+                 predefined_bucket_id = np.abs(ar_errors).argmin()  # the resolution with the smallest aspect ratio error other than the exact one
+                 reso = self.predefined_resos[predefined_bucket_id]
+
+             ar_reso = reso[0] / reso[1]
+             if aspect_ratio > ar_reso:  # wider than the bucket -> match the height
+                 scale = reso[1] / image_height
+             else:
+                 scale = reso[0] / image_width
+
+             resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
+             # print("use predef", image_width, image_height, reso, resized_size)
+         else:
+             if image_width * image_height > self.max_area:
+                 # the image is too large, so choose the bucket assuming it will be shrunk while keeping the aspect ratio
+                 resized_width = math.sqrt(self.max_area * aspect_ratio)
+                 resized_height = self.max_area / resized_width
+                 assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
+
+                 # round the resized short or long side to reso_steps; pick whichever gives the smaller aspect-ratio difference
+                 # same logic as the original bucketing
+                 b_width_rounded = self.round_to_steps(resized_width)
+                 b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
+                 ar_width_rounded = b_width_rounded / b_height_in_wr
+
+                 b_height_rounded = self.round_to_steps(resized_height)
+                 b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
+                 ar_height_rounded = b_width_in_hr / b_height_rounded
+
+                 # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
+                 # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
+
+                 if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
+                     resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
+                 else:
+                     resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
+                 # print(resized_size)
+             else:
+                 resized_size = (image_width, image_height)  # no resize needed
+
+             # use a bucket size no larger than the image (crop instead of padding)
+             bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
+             bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
+             # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
+
+             reso = (bucket_width, bucket_height)
+
+             self.add_if_new_reso(reso)
+
+         ar_error = (reso[0] / reso[1]) - aspect_ratio
+         return reso, resized_size, ar_error
+
+ class DreamBoothDataset(train_util.DreamBoothDataset):
+     def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
+         print("use append DreamBoothDataset")
+         self.min_resolution = min_resolution
+         self.area_step = area_step
+         super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
+                          resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
+                          flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
+     def make_buckets(self):
+         '''
+         Must be called even when bucketing is disabled (a single bucket is created).
+         min_size and max_size are ignored when enable_bucket is False
+         '''
+         print("loading image sizes.")
+         for info in tqdm(self.image_data.values()):
+             if info.image_size is None:
+                 info.image_size = self.get_image_size(info.absolute_path)
+
+         if self.enable_bucket:
+             print("make buckets")
+         else:
+             print("prepare dataset")
+
+         # create the buckets and assign each image to one
+         if self.enable_bucket:
+             if self.bucket_manager is None:  # already initialized when fine tuning and the metadata defines the buckets
+                 #======================================================================change
+                 if self.min_resolution:
+                     self.bucket_manager = BucketManager_append(self.bucket_no_upscale, (self.width, self.height),
+                                                                self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps, self.min_resolution, self.area_step)
+                 else:
+                     self.bucket_manager = train_util.BucketManager(self.bucket_no_upscale, (self.width, self.height),
+                                                                    self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
+                 #======================================================================change
+                 if not self.bucket_no_upscale:
+                     self.bucket_manager.make_buckets()
+                 else:
+                     print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
+
+             img_ar_errors = []
+             for image_info in self.image_data.values():
+                 image_width, image_height = image_info.image_size
+                 image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
+
+                 # print(image_info.image_key, image_info.bucket_reso)
+                 img_ar_errors.append(abs(ar_error))
+
+             self.bucket_manager.sort()
+         else:
+             self.bucket_manager = train_util.BucketManager(False, (self.width, self.height), None, None, None)
+             self.bucket_manager.set_predefined_resos([(self.width, self.height)])  # only a single fixed-size bucket
+             for image_info in self.image_data.values():
+                 image_width, image_height = image_info.image_size
+                 image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
+
+         for image_info in self.image_data.values():
+             for _ in range(image_info.num_repeats):
+                 self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
+
+         # print and store the bucket info
+         if self.enable_bucket:
+             self.bucket_info = {"buckets": {}}
+             print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
+             for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
+                 count = len(bucket)
+                 if count > 0:
+                     self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
+                     print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
+
+             img_ar_errors = np.array(img_ar_errors)
+             mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+             self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
+             print(f"mean ar error (without repeats): {mean_img_ar_error}")
+
+         # build the index used to reference the data; this index is used when shuffling the dataset
+         self.buckets_indices: List[train_util.BucketBatchIndex] = []
+         for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
+             batch_count = int(math.ceil(len(bucket) / self.batch_size))
+             for batch_index in range(batch_count):
+                 self.buckets_indices.append(train_util.BucketBatchIndex(bucket_index, self.batch_size, batch_index))
+
+         # ↓ The following was reverted because it made the per-bucket batch count grow so much that it became confusing
+         #   Since training steps are taken in random order, having the same image within one batch should not hurt much
+         #
+         # # As buckets get more fine-grained, cases where a bucket holds only one kind of image increase, meaning
+         # # a whole batch could consist of the same image, which is clearly undesirable
+         # # So the batch size was limited to the number of distinct image types
+         # # Even then the same image can still appear in the same batch, so fewer repeats surely give better shuffle quality?
+         # # TO DO: a mechanism to use regularization images across epochs
+         # num_of_image_types = len(set(bucket))
+         # bucket_batch_size = min(self.batch_size, num_of_image_types)
+         # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
+         # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
+         # for batch_index in range(batch_count):
+         #     self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
+         # ↑ up to here
+
+         self.shuffle_buckets()
+         self._length = len(self.buckets_indices)
+
+ class FineTuningDataset(train_util.FineTuningDataset):
+     def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
+         train_util.glob_images = glob_images
+         super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
+                          resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
+                          random_crop, dataset_repeats, debug_dataset)
+
+ def glob_images(directory, base="*", npz_flag=True):
+     img_paths = []
+     dots = []
+     for ext in train_util.IMAGE_EXTENSIONS:
+         dots.append(ext)
+     if npz_flag:
+         dots.append(".npz")
+     for ext in dots:
+         if base == '*':
+             img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+         else:
+             img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+     return img_paths
+
+ #============================================================================================================
+ # networks.lora
+ #============================================================================================================
+ from networks.lora import LoRANetwork
+ def replace_prepare_optimizer_params(networks):
+     def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
+
+         def enumerate_params(loras, lora_name=None):
+             params = []
+             for lora in loras:
+                 if lora_name is not None:
+                     if lora_name in lora.lora_name:
+                         params.extend(lora.parameters())
+                 else:
+                     params.extend(lora.parameters())
+             return params
+
+         self.requires_grad_(True)
+         all_params = []
+         ret_scheduler_lr = []
+
+         textencoder_names = [None]
+         unet_names = [None]
+         if loranames is not None:
+             if "text_encoder" in loranames:
+                 textencoder_names = loranames["text_encoder"]
+             if "unet" in loranames:
+                 unet_names = loranames["unet"]
+
+         if self.text_encoder_loras:
+             for textencoder_name in textencoder_names:
+                 param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
+                 if text_encoder_lr is not None:
+                     param_data['lr'] = text_encoder_lr
+                 if scheduler_lr is not None:
+                     ret_scheduler_lr.append(scheduler_lr[0])
+                 all_params.append(param_data)
+
+         if self.unet_loras:
+             for unet_name in unet_names:
+                 param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
+                 if unet_lr is not None:
+                     param_data['lr'] = unet_lr
+                 if scheduler_lr is not None:
+                     ret_scheduler_lr.append(scheduler_lr[1])
+                 all_params.append(param_data)
+
+         return all_params, ret_scheduler_lr
+
+     LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
+
+ #============================================================================================================
+ # Newly added
+ #============================================================================================================
+ def add_append_arguments(parser: argparse.ArgumentParser):
+     # for train_network_opt.py
+     parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
+     parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
+     parser.add_argument("--split_lora_networks", action="store_true")
+     parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
+     parser.add_argument("--min_resolution", type=str, default=None)
+     parser.add_argument("--area_step", type=int, default=1)
+     parser.add_argument("--config", type=str, default=None)
+
+ def create_split_names(split_flag, split_level):
+     split_names = None
+     if split_flag:
+         split_names = {}
+         text_encoder_names = [None]
+         unet_names = ["lora_unet_mid_block"]
+         if split_level==1:
+             unet_names.append(f"lora_unet_down_blocks_")
+             unet_names.append(f"lora_unet_up_blocks_")
+         elif split_level==2 or split_level==0:
+             if split_level==2:
+                 text_encoder_names = []
+                 for i in range(12):
+                     text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
+             for i in range(3):
+                 unet_names.append(f"lora_unet_down_blocks_{i}")
+                 unet_names.append(f"lora_unet_up_blocks_{i+1}")
+         split_names["text_encoder"] = text_encoder_names
+         split_names["unet"] = unet_names
+     return split_names
+
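For reference, what create_split_names produces for the coarse split (split_level=1); the level-0/2 variants additionally enumerate the individual down/up blocks and, for level 2, the twelve text-encoder layers:

names = create_split_names(split_flag=True, split_level=1)
# {'text_encoder': [None],
#  'unet': ['lora_unet_mid_block', 'lora_unet_down_blocks_', 'lora_unet_up_blocks_']}
print(names)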
+ def get_config(parser):
+     args = parser.parse_args()
+     if args.config is not None and (not args.config==""):
+         import yaml
+         import datetime
+         if os.path.splitext(args.config)[-1] == ".yaml":
+             args.config = os.path.splitext(args.config)[0]
+         config_path = f"./{args.config}.yaml"
+         if os.path.exists(config_path):
+             print(f"{config_path} から設定を読み込み中...")
+             margs, rest = parser.parse_known_args()
+             with open(config_path, mode="r") as f:
+                 configs = yaml.unsafe_load(f)
+             # pull a dict out of the argparser so values can be exchanged via variables
+             args_dic = vars(args)
+             # find the arguments that were changed from their defaults on the command line
+             change_def_dic = {}
+             args_type_dic = {}
+             for key, v in args_dic.items():
+                 if not parser.get_default(key) == v:
+                     change_def_dic[key] = v
+             # get the data types declared in the parser defaults
+             for key, act in parser._option_string_actions.items():
+                 if key=="-h": continue
+                 key = key[2:]
+                 args_type_dic[key] = act.type
+             # check data types and assign each key's value into args
+             for key, v in configs.items():
+                 if key in args_dic:
+                     if args_dic[key] is not None:
+                         new_type = type(args_dic[key])
+                         if (not type(v) == new_type) and (not new_type==list):
+                             v = new_type(v)
+                     else:
+                         if v is not None:
+                             if not type(v) == args_type_dic[key]:
+                                 v = args_type_dic[key](v)
+                     args_dic[key] = v
+             # finally, re-apply the values that were explicitly changed from the defaults
+             for key, v in change_def_dic.items():
+                 args_dic[key] = v
+         else:
+             print(f"{config_path} が見つかりませんでした")
+     return args
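A sketch of how these helpers are meant to be combined in a training script; the exact call site is an assumption, not shown in this commit:

import argparse

parser = argparse.ArgumentParser()
add_append_arguments(parser)   # adds --optimizer, --split_lora_networks, --config, etc.
args = get_config(parser)      # merges ./<name>.yaml over the parsed CLI args when --config is given
split_names = create_split_names(args.split_lora_networks, args.split_lora_level)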
build/lib/library/__init__.py ADDED
File without changes
build/lib/library/model_util.py ADDED
@@ -0,0 +1,1180 @@
+ # v1: split from train_db_fixed.py.
+ # v2: support safetensors
+
+ import math
+ import os
+ import torch
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+ from safetensors.torch import load_file, save_file
+
+ # Model parameters of the Diffusers version of Stable Diffusion
+ NUM_TRAIN_TIMESTEPS = 1000
+ BETA_START = 0.00085
+ BETA_END = 0.0120
+
+ UNET_PARAMS_MODEL_CHANNELS = 320
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+ UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
+ UNET_PARAMS_IN_CHANNELS = 4
+ UNET_PARAMS_OUT_CHANNELS = 4
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
+ UNET_PARAMS_CONTEXT_DIM = 768
+ UNET_PARAMS_NUM_HEADS = 8
+
+ VAE_PARAMS_Z_CHANNELS = 4
+ VAE_PARAMS_RESOLUTION = 256
+ VAE_PARAMS_IN_CHANNELS = 3
+ VAE_PARAMS_OUT_CH = 3
+ VAE_PARAMS_CH = 128
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+ # V2
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+ # Reference models used to load the Diffusers configuration
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
+
+
+ # region Conversion code: StableDiffusion -> Diffusers
+ # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
+
+
+ def shave_segments(path, n_shave_prefix_segments=1):
+     """
+     Removes segments. Positive values shave the first segments, negative shave the last segments.
+     """
+     if n_shave_prefix_segments >= 0:
+         return ".".join(path.split(".")[n_shave_prefix_segments:])
+     else:
+         return ".".join(path.split(".")[:n_shave_prefix_segments])
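shave_segments simply drops dot-separated prefix (or, for negative n, suffix) segments of a state-dict key; a quick illustration:

print(shave_segments("input_blocks.3.1.in_layers.0.weight", n_shave_prefix_segments=2))
# -> "1.in_layers.0.weight"
print(shave_segments("a.b.c", n_shave_prefix_segments=-1))
# -> "a.b"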
55
+
56
+
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
+
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
+
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
+
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
+
111
+ mapping.append({"old": old_item, "new": new_item})
112
+
113
+ return mapping
114
+
115
+
116
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
+
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
+
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
+
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
+
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
+
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
+
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
+
141
+ mapping.append({"old": old_item, "new": new_item})
142
+
143
+ return mapping
144
+
145
+
146
+ def assign_to_checkpoint(
147
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
+
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
+
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
+
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
+
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
+
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
+
175
+ for path in paths:
176
+ new_path = path["new"]
177
+
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
+
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
+
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
+
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
+
197
+
198
+ def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
+
209
+
210
+ def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
+
218
+
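For illustration only (not part of the uploaded file): conv_attn_to_linear and linear_transformer_to_conv only reshape weights; the v2 helper turns 2-D proj_in/proj_out weights into 1x1 convolution kernels.

    import torch  # illustrative; torch is already imported at the top of this module

    sd = {"input_blocks.1.1.proj_in.weight": torch.randn(320, 320)}
    linear_transformer_to_conv(sd)
    print(sd["input_blocks.1.1.proj_in.weight"].shape)  # torch.Size([320, 320, 1, 1])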
219
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
251
+ for layer_id in range(num_input_blocks)
252
+ }
253
+
254
+ # Retrieves the keys for the middle blocks only
255
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
256
+ middle_blocks = {
257
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
258
+ for layer_id in range(num_middle_blocks)
259
+ }
260
+
261
+ # Retrieves the keys for the output blocks only
262
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
263
+ output_blocks = {
264
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
265
+ for layer_id in range(num_output_blocks)
266
+ }
267
+
268
+ for i in range(1, num_input_blocks):
269
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
270
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
271
+
272
+ resnets = [
273
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
274
+ ]
275
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
276
+
277
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
278
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
279
+ f"input_blocks.{i}.0.op.weight"
280
+ )
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.bias"
283
+ )
284
+
285
+ paths = renew_resnet_paths(resnets)
286
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
287
+ assign_to_checkpoint(
288
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
289
+ )
290
+
291
+ if len(attentions):
292
+ paths = renew_attention_paths(attentions)
293
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
294
+ assign_to_checkpoint(
295
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
296
+ )
297
+
298
+ resnet_0 = middle_blocks[0]
299
+ attentions = middle_blocks[1]
300
+ resnet_1 = middle_blocks[2]
301
+
302
+ resnet_0_paths = renew_resnet_paths(resnet_0)
303
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
304
+
305
+ resnet_1_paths = renew_resnet_paths(resnet_1)
306
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ attentions_paths = renew_attention_paths(attentions)
309
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
310
+ assign_to_checkpoint(
311
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
312
+ )
313
+
314
+ for i in range(num_output_blocks):
315
+ block_id = i // (config["layers_per_block"] + 1)
316
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
317
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
318
+ output_block_list = {}
319
+
320
+ for layer in output_block_layers:
321
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
322
+ if layer_id in output_block_list:
323
+ output_block_list[layer_id].append(layer_name)
324
+ else:
325
+ output_block_list[layer_id] = [layer_name]
326
+
327
+ if len(output_block_list) > 1:
328
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
329
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
330
+
331
+ resnet_0_paths = renew_resnet_paths(resnets)
332
+ paths = renew_resnet_paths(resnets)
333
+
334
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
335
+ assign_to_checkpoint(
336
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
337
+ )
338
+
339
+ # original:
340
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
341
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
342
+
343
+ # make this independent of the order of bias and weight; there is probably a better way
344
+ for l in output_block_list.values():
345
+ l.sort()
346
+
347
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
348
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
349
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
350
+ f"output_blocks.{i}.{index}.conv.bias"
351
+ ]
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.weight"
354
+ ]
355
+
356
+ # Clear attentions as they have been attributed above.
357
+ if len(attentions) == 2:
358
+ attentions = []
359
+
360
+ if len(attentions):
361
+ paths = renew_attention_paths(attentions)
362
+ meta_path = {
363
+ "old": f"output_blocks.{i}.1",
364
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
365
+ }
366
+ assign_to_checkpoint(
367
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
368
+ )
369
+ else:
370
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
371
+ for path in resnet_0_paths:
372
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
373
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
374
+
375
+ new_checkpoint[new_path] = unet_state_dict[old_path]
376
+
377
+ # in SD v2 the 1x1 conv2d layers were replaced with linear layers, so convert linear -> conv
378
+ if v2:
379
+ linear_transformer_to_conv(new_checkpoint)
380
+
381
+ return new_checkpoint
382
+
383
+
384
+ def convert_ldm_vae_checkpoint(checkpoint, config):
385
+ # extract state dict for VAE
386
+ vae_state_dict = {}
387
+ vae_key = "first_stage_model."
388
+ keys = list(checkpoint.keys())
389
+ for key in keys:
390
+ if key.startswith(vae_key):
391
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
392
+ # if len(vae_state_dict) == 0:
393
+ # # the given checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt file
394
+ # vae_state_dict = checkpoint
395
+
396
+ new_checkpoint = {}
397
+
398
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
399
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
400
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
401
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
402
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
403
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
404
+
405
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
406
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
407
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
408
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
409
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
410
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
411
+
412
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
413
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
414
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
415
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
416
+
417
+ # Retrieves the keys for the encoder down blocks only
418
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
419
+ down_blocks = {
420
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
421
+ }
422
+
423
+ # Retrieves the keys for the decoder up blocks only
424
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
425
+ up_blocks = {
426
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
427
+ }
428
+
429
+ for i in range(num_down_blocks):
430
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
431
+
432
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
433
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
434
+ f"encoder.down.{i}.downsample.conv.weight"
435
+ )
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.bias"
438
+ )
439
+
440
+ paths = renew_vae_resnet_paths(resnets)
441
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
442
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
443
+
444
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
445
+ num_mid_res_blocks = 2
446
+ for i in range(1, num_mid_res_blocks + 1):
447
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
448
+
449
+ paths = renew_vae_resnet_paths(resnets)
450
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
451
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
452
+
453
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
454
+ paths = renew_vae_attention_paths(mid_attentions)
455
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
456
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
457
+ conv_attn_to_linear(new_checkpoint)
458
+
459
+ for i in range(num_up_blocks):
460
+ block_id = num_up_blocks - 1 - i
461
+ resnets = [
462
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
463
+ ]
464
+
465
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
466
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
467
+ f"decoder.up.{block_id}.upsample.conv.weight"
468
+ ]
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.bias"
471
+ ]
472
+
473
+ paths = renew_vae_resnet_paths(resnets)
474
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
475
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
476
+
477
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
478
+ num_mid_res_blocks = 2
479
+ for i in range(1, num_mid_res_blocks + 1):
480
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
481
+
482
+ paths = renew_vae_resnet_paths(resnets)
483
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
484
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
485
+
486
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
487
+ paths = renew_vae_attention_paths(mid_attentions)
488
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
489
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
490
+ conv_attn_to_linear(new_checkpoint)
491
+ return new_checkpoint
492
+
493
+
494
+ def create_unet_diffusers_config(v2):
495
+ """
496
+ Creates a config for the diffusers based on the config of the LDM model.
497
+ """
498
+ # unet_params = original_config.model.params.unet_config.params
499
+
500
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
501
+
502
+ down_block_types = []
503
+ resolution = 1
504
+ for i in range(len(block_out_channels)):
505
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
506
+ down_block_types.append(block_type)
507
+ if i != len(block_out_channels) - 1:
508
+ resolution *= 2
509
+
510
+ up_block_types = []
511
+ for i in range(len(block_out_channels)):
512
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
513
+ up_block_types.append(block_type)
514
+ resolution //= 2
515
+
516
+ config = dict(
517
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
518
+ in_channels=UNET_PARAMS_IN_CHANNELS,
519
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
520
+ down_block_types=tuple(down_block_types),
521
+ up_block_types=tuple(up_block_types),
522
+ block_out_channels=tuple(block_out_channels),
523
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
524
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
525
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
526
+ )
527
+
528
+ return config
529
+
530
+
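As a rough illustration only (not part of the uploaded file, and assuming the usual SD v1 constants defined earlier in this module: 320 base channels, channel multipliers [1, 2, 4, 4], attention at resolutions 1, 2 and 4), the config built above would contain:

    # illustrative expected values from create_unet_diffusers_config(v2=False) under the assumptions above
    # block_out_channels = (320, 640, 1280, 1280)
    # down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
    # up_block_types = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")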
531
+ def create_vae_diffusers_config():
532
+ """
533
+ Creates a config for the diffusers based on the config of the LDM model.
534
+ """
535
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
536
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
537
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
538
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
539
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
540
+
541
+ config = dict(
542
+ sample_size=VAE_PARAMS_RESOLUTION,
543
+ in_channels=VAE_PARAMS_IN_CHANNELS,
544
+ out_channels=VAE_PARAMS_OUT_CH,
545
+ down_block_types=tuple(down_block_types),
546
+ up_block_types=tuple(up_block_types),
547
+ block_out_channels=tuple(block_out_channels),
548
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
549
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
550
+ )
551
+ return config
552
+
553
+
554
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
555
+ keys = list(checkpoint.keys())
556
+ text_model_dict = {}
557
+ for key in keys:
558
+ if key.startswith("cond_stage_model.transformer"):
559
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
560
+ return text_model_dict
561
+
562
+
563
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
564
+ # the keys differ annoyingly much!
565
+ def convert_key(key):
566
+ if not key.startswith("cond_stage_model"):
567
+ return None
568
+
569
+ # common conversion
570
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
571
+ key = key.replace("cond_stage_model.model.", "text_model.")
572
+
573
+ if "resblocks" in key:
574
+ # resblocks conversion
575
+ key = key.replace(".resblocks.", ".layers.")
576
+ if ".ln_" in key:
577
+ key = key.replace(".ln_", ".layer_norm")
578
+ elif ".mlp." in key:
579
+ key = key.replace(".c_fc.", ".fc1.")
580
+ key = key.replace(".c_proj.", ".fc2.")
581
+ elif '.attn.out_proj' in key:
582
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
583
+ elif '.attn.in_proj' in key:
584
+ key = None # special case, handled later
585
+ else:
586
+ raise ValueError(f"unexpected key in SD: {key}")
587
+ elif '.positional_embedding' in key:
588
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
589
+ elif '.text_projection' in key:
590
+ key = None # apparently unused???
591
+ elif '.logit_scale' in key:
592
+ key = None # apparently unused???
593
+ elif '.token_embedding' in key:
594
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
595
+ elif '.ln_final' in key:
596
+ key = key.replace(".ln_final", ".final_layer_norm")
597
+ return key
598
+
599
+ keys = list(checkpoint.keys())
600
+ new_sd = {}
601
+ for key in keys:
602
+ # remove resblocks 23
603
+ if '.resblocks.23.' in key:
604
+ continue
605
+ new_key = convert_key(key)
606
+ if new_key is None:
607
+ continue
608
+ new_sd[new_key] = checkpoint[key]
609
+
610
+ # convert attention weights
611
+ for key in keys:
612
+ if '.resblocks.23.' in key:
613
+ continue
614
+ if '.resblocks' in key and '.attn.in_proj_' in key:
615
+ # split into three (q, k, v)
616
+ values = torch.chunk(checkpoint[key], 3)
617
+
618
+ key_suffix = ".weight" if "weight" in key else ".bias"
619
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
620
+ key_pfx = key_pfx.replace("_weight", "")
621
+ key_pfx = key_pfx.replace("_bias", "")
622
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
623
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
624
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
625
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
626
+
627
+ # rename or add position_ids
628
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
629
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
630
+ # waifu diffusion v1.4
631
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
632
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
633
+ else:
634
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
635
+
636
+ new_sd["text_model.embeddings.position_ids"] = position_ids
637
+ return new_sd
638
+
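For illustration only (not part of the uploaded file): the conversion above renames the SD v2 text encoder keys roughly as follows.

    # illustrative key mapping performed by convert_ldm_clip_checkpoint_v2
    # cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight
    #   -> text_model.encoder.layers.0.mlp.fc1.weight
    # cond_stage_model.model.ln_final.weight -> text_model.final_layer_norm.weight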
639
+ # endregion
640
+
641
+
642
+ # region Diffusers->StableDiffusion conversion code
643
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
644
+
645
+ def conv_transformer_to_linear(checkpoint):
646
+ keys = list(checkpoint.keys())
647
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
648
+ for key in keys:
649
+ if ".".join(key.split(".")[-2:]) in tf_keys:
650
+ if checkpoint[key].ndim > 2:
651
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
652
+
653
+
654
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
655
+ unet_conversion_map = [
656
+ # (stable-diffusion, HF Diffusers)
657
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
658
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
659
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
660
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
661
+ ("input_blocks.0.0.weight", "conv_in.weight"),
662
+ ("input_blocks.0.0.bias", "conv_in.bias"),
663
+ ("out.0.weight", "conv_norm_out.weight"),
664
+ ("out.0.bias", "conv_norm_out.bias"),
665
+ ("out.2.weight", "conv_out.weight"),
666
+ ("out.2.bias", "conv_out.bias"),
667
+ ]
668
+
669
+ unet_conversion_map_resnet = [
670
+ # (stable-diffusion, HF Diffusers)
671
+ ("in_layers.0", "norm1"),
672
+ ("in_layers.2", "conv1"),
673
+ ("out_layers.0", "norm2"),
674
+ ("out_layers.3", "conv2"),
675
+ ("emb_layers.1", "time_emb_proj"),
676
+ ("skip_connection", "conv_shortcut"),
677
+ ]
678
+
679
+ unet_conversion_map_layer = []
680
+ for i in range(4):
681
+ # loop over downblocks/upblocks
682
+
683
+ for j in range(2):
684
+ # loop over resnets/attentions for downblocks
685
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
686
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
687
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
688
+
689
+ if i < 3:
690
+ # no attention layers in down_blocks.3
691
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
692
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
693
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
694
+
695
+ for j in range(3):
696
+ # loop over resnets/attentions for upblocks
697
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
698
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
699
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
700
+
701
+ if i > 0:
702
+ # no attention layers in up_blocks.0
703
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
704
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
705
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
706
+
707
+ if i < 3:
708
+ # no downsample in down_blocks.3
709
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
710
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
711
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
712
+
713
+ # no upsample in up_blocks.3
714
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
715
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
716
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
717
+
718
+ hf_mid_atn_prefix = "mid_block.attentions.0."
719
+ sd_mid_atn_prefix = "middle_block.1."
720
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
721
+
722
+ for j in range(2):
723
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
724
+ sd_mid_res_prefix = f"middle_block.{2*j}."
725
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
726
+
727
+ # buyer beware: this is a *brittle* function,
728
+ # and correct output requires that all of these pieces interact in
729
+ # the exact order in which I have arranged them.
730
+ mapping = {k: k for k in unet_state_dict.keys()}
731
+ for sd_name, hf_name in unet_conversion_map:
732
+ mapping[hf_name] = sd_name
733
+ for k, v in mapping.items():
734
+ if "resnets" in k:
735
+ for sd_part, hf_part in unet_conversion_map_resnet:
736
+ v = v.replace(hf_part, sd_part)
737
+ mapping[k] = v
738
+ for k, v in mapping.items():
739
+ for sd_part, hf_part in unet_conversion_map_layer:
740
+ v = v.replace(hf_part, sd_part)
741
+ mapping[k] = v
742
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
743
+
744
+ if v2:
745
+ conv_transformer_to_linear(new_state_dict)
746
+
747
+ return new_state_dict
748
+
749
+
750
+ # ================#
751
+ # VAE Conversion #
752
+ # ================#
753
+
754
+ def reshape_weight_for_sd(w):
755
+ # convert HF linear weights to SD conv2d weights
756
+ return w.reshape(*w.shape, 1, 1)
757
+
758
+
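For illustration only (not part of the uploaded file): the reshape above simply appends two singleton spatial dimensions so a linear weight can be stored as a 1x1 conv weight.

    import torch  # illustrative; torch is already imported at the top of this module

    w = torch.randn(512, 512)
    print(reshape_weight_for_sd(w).shape)  # torch.Size([512, 512, 1, 1])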
759
+ def convert_vae_state_dict(vae_state_dict):
760
+ vae_conversion_map = [
761
+ # (stable-diffusion, HF Diffusers)
762
+ ("nin_shortcut", "conv_shortcut"),
763
+ ("norm_out", "conv_norm_out"),
764
+ ("mid.attn_1.", "mid_block.attentions.0."),
765
+ ]
766
+
767
+ for i in range(4):
768
+ # down_blocks have two resnets
769
+ for j in range(2):
770
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
771
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
772
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
773
+
774
+ if i < 3:
775
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
776
+ sd_downsample_prefix = f"down.{i}.downsample."
777
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
778
+
779
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
780
+ sd_upsample_prefix = f"up.{3-i}.upsample."
781
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
782
+
783
+ # up_blocks have three resnets
784
+ # also, up blocks in hf are numbered in reverse from sd
785
+ for j in range(3):
786
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
787
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
788
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
789
+
790
+ # this part accounts for mid blocks in both the encoder and the decoder
791
+ for i in range(2):
792
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
793
+ sd_mid_res_prefix = f"mid.block_{i+1}."
794
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
795
+
796
+ vae_conversion_map_attn = [
797
+ # (stable-diffusion, HF Diffusers)
798
+ ("norm.", "group_norm."),
799
+ ("q.", "query."),
800
+ ("k.", "key."),
801
+ ("v.", "value."),
802
+ ("proj_out.", "proj_attn."),
803
+ ]
804
+
805
+ mapping = {k: k for k in vae_state_dict.keys()}
806
+ for k, v in mapping.items():
807
+ for sd_part, hf_part in vae_conversion_map:
808
+ v = v.replace(hf_part, sd_part)
809
+ mapping[k] = v
810
+ for k, v in mapping.items():
811
+ if "attentions" in k:
812
+ for sd_part, hf_part in vae_conversion_map_attn:
813
+ v = v.replace(hf_part, sd_part)
814
+ mapping[k] = v
815
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
816
+ weights_to_convert = ["q", "k", "v", "proj_out"]
817
+ for k, v in new_state_dict.items():
818
+ for weight_name in weights_to_convert:
819
+ if f"mid.attn_1.{weight_name}.weight" in k:
820
+ # print(f"Reshaping {k} for SD format")
821
+ new_state_dict[k] = reshape_weight_for_sd(v)
822
+
823
+ return new_state_dict
824
+
825
+
826
+ # endregion
827
+
828
+ # region custom model loading/saving etc.
829
+
830
+ def is_safetensors(path):
831
+ return os.path.splitext(path)[1].lower() == '.safetensors'
832
+
833
+
834
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
835
+ # support models that store the text encoder in a different format (missing 'text_model')
836
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
837
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
838
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
839
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
840
+ ]
841
+
842
+ if is_safetensors(ckpt_path):
843
+ checkpoint = None
844
+ state_dict = load_file(ckpt_path, "cpu")
845
+ else:
846
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
847
+ if "state_dict" in checkpoint:
848
+ state_dict = checkpoint["state_dict"]
849
+ else:
850
+ state_dict = checkpoint
851
+ checkpoint = None
852
+
853
+ key_reps = []
854
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
855
+ for key in state_dict.keys():
856
+ if key.startswith(rep_from):
857
+ new_key = rep_to + key[len(rep_from):]
858
+ key_reps.append((key, new_key))
859
+
860
+ for key, new_key in key_reps:
861
+ state_dict[new_key] = state_dict[key]
862
+ del state_dict[key]
863
+
864
+ return checkpoint, state_dict
865
+
866
+
867
+ # TODO the dtype handling looks questionable and should be verified; not yet confirmed that the text_encoder can be created with the specified dtype
868
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
869
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
870
+ if dtype is not None:
871
+ for k, v in state_dict.items():
872
+ if type(v) is torch.Tensor:
873
+ state_dict[k] = v.to(dtype)
874
+
875
+ # Convert the UNet2DConditionModel model.
876
+ unet_config = create_unet_diffusers_config(v2)
877
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
878
+
879
+ unet = UNet2DConditionModel(**unet_config)
880
+ info = unet.load_state_dict(converted_unet_checkpoint)
881
+ print("loading u-net:", info)
882
+
883
+ # Convert the VAE model.
884
+ vae_config = create_vae_diffusers_config()
885
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
886
+
887
+ vae = AutoencoderKL(**vae_config)
888
+ info = vae.load_state_dict(converted_vae_checkpoint)
889
+ print("loading vae:", info)
890
+
891
+ # convert text_model
892
+ if v2:
893
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
894
+ cfg = CLIPTextConfig(
895
+ vocab_size=49408,
896
+ hidden_size=1024,
897
+ intermediate_size=4096,
898
+ num_hidden_layers=23,
899
+ num_attention_heads=16,
900
+ max_position_embeddings=77,
901
+ hidden_act="gelu",
902
+ layer_norm_eps=1e-05,
903
+ dropout=0.0,
904
+ attention_dropout=0.0,
905
+ initializer_range=0.02,
906
+ initializer_factor=1.0,
907
+ pad_token_id=1,
908
+ bos_token_id=0,
909
+ eos_token_id=2,
910
+ model_type="clip_text_model",
911
+ projection_dim=512,
912
+ torch_dtype="float32",
913
+ transformers_version="4.25.0.dev0",
914
+ )
915
+ text_model = CLIPTextModel._from_config(cfg)
916
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
+ else:
918
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
920
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
+ print("loading text encoder:", info)
922
+
923
+ return text_model, vae, unet
924
+
925
+
926
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
927
+ def convert_key(key):
928
+ # remove position_ids
929
+ if ".position_ids" in key:
930
+ return None
931
+
932
+ # common
933
+ key = key.replace("text_model.encoder.", "transformer.")
934
+ key = key.replace("text_model.", "")
935
+ if "layers" in key:
936
+ # resblocks conversion
937
+ key = key.replace(".layers.", ".resblocks.")
938
+ if ".layer_norm" in key:
939
+ key = key.replace(".layer_norm", ".ln_")
940
+ elif ".mlp." in key:
941
+ key = key.replace(".fc1.", ".c_fc.")
942
+ key = key.replace(".fc2.", ".c_proj.")
943
+ elif '.self_attn.out_proj' in key:
944
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
945
+ elif '.self_attn.' in key:
946
+ key = None # special case, handled later
947
+ else:
948
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
949
+ elif '.position_embedding' in key:
950
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
951
+ elif '.token_embedding' in key:
952
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
953
+ elif 'final_layer_norm' in key:
954
+ key = key.replace("final_layer_norm", "ln_final")
955
+ return key
956
+
957
+ keys = list(checkpoint.keys())
958
+ new_sd = {}
959
+ for key in keys:
960
+ new_key = convert_key(key)
961
+ if new_key is None:
962
+ continue
963
+ new_sd[new_key] = checkpoint[key]
964
+
965
+ # convert attention weights
966
+ for key in keys:
967
+ if 'layers' in key and 'q_proj' in key:
968
+ # concatenate the three (q, k, v)
969
+ key_q = key
970
+ key_k = key.replace("q_proj", "k_proj")
971
+ key_v = key.replace("q_proj", "v_proj")
972
+
973
+ value_q = checkpoint[key_q]
974
+ value_k = checkpoint[key_k]
975
+ value_v = checkpoint[key_v]
976
+ value = torch.cat([value_q, value_k, value_v])
977
+
978
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
979
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
980
+ new_sd[new_key] = value
981
+
982
+ # whether to fabricate the last layer and other missing weights
983
+ if make_dummy_weights:
984
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
985
+ keys = list(new_sd.keys())
986
+ for key in keys:
987
+ if key.startswith("transformer.resblocks.22."):
988
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # must clone, otherwise saving with safetensors fails
989
+
990
+ # create the weights that are not included in the Diffusers model
991
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
992
+ new_sd['logit_scale'] = torch.tensor(1)
993
+
994
+ return new_sd
995
+
996
+
997
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
998
+ if ckpt_path is not None:
999
+ # read epoch/step from it; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
1000
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1001
+ if checkpoint is None: # safetensors or a state_dict-only ckpt
1002
+ checkpoint = {}
1003
+ strict = False
1004
+ else:
1005
+ strict = True
1006
+ if "state_dict" in state_dict:
1007
+ del state_dict["state_dict"]
1008
+ else:
1009
+ # create a new checkpoint
1010
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1011
+ checkpoint = {}
1012
+ state_dict = {}
1013
+ strict = False
1014
+
1015
+ def update_sd(prefix, sd):
1016
+ for k, v in sd.items():
1017
+ key = prefix + k
1018
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1019
+ if save_dtype is not None:
1020
+ v = v.detach().clone().to("cpu").to(save_dtype)
1021
+ state_dict[key] = v
1022
+
1023
+ # Convert the UNet model
1024
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1025
+ update_sd("model.diffusion_model.", unet_state_dict)
1026
+
1027
+ # Convert the text encoder model
1028
+ if v2:
1029
+ make_dummy = ckpt_path is None # when there is no reference checkpoint, insert dummy weights, e.g. creating the last layer by duplicating the previous one
1030
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1031
+ update_sd("cond_stage_model.model.", text_enc_dict)
1032
+ else:
1033
+ text_enc_dict = text_encoder.state_dict()
1034
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1035
+
1036
+ # Convert the VAE
1037
+ if vae is not None:
1038
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1039
+ update_sd("first_stage_model.", vae_dict)
1040
+
1041
+ # Put together new checkpoint
1042
+ key_count = len(state_dict.keys())
1043
+ new_ckpt = {'state_dict': state_dict}
1044
+
1045
+ if 'epoch' in checkpoint:
1046
+ epochs += checkpoint['epoch']
1047
+ if 'global_step' in checkpoint:
1048
+ steps += checkpoint['global_step']
1049
+
1050
+ new_ckpt['epoch'] = epochs
1051
+ new_ckpt['global_step'] = steps
1052
+
1053
+ if is_safetensors(output_file):
1054
+ # TODO should non-Tensor values be removed from the dict?
1055
+ save_file(state_dict, output_file)
1056
+ else:
1057
+ torch.save(new_ckpt, output_file)
1058
+
1059
+ return key_count
1060
+
1061
+
1062
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1063
+ if pretrained_model_name_or_path is None:
1064
+ # load default settings for v1/v2
1065
+ if v2:
1066
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1067
+ else:
1068
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1069
+
1070
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1071
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1072
+ if vae is None:
1073
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1074
+
1075
+ pipeline = StableDiffusionPipeline(
1076
+ unet=unet,
1077
+ text_encoder=text_encoder,
1078
+ vae=vae,
1079
+ scheduler=scheduler,
1080
+ tokenizer=tokenizer,
1081
+ safety_checker=None,
1082
+ feature_extractor=None,
1083
+ requires_safety_checker=None,
1084
+ )
1085
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1086
+
1087
+
1088
+ VAE_PREFIX = "first_stage_model."
1089
+
1090
+
1091
+ def load_vae(vae_id, dtype):
1092
+ print(f"load VAE: {vae_id}")
1093
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1094
+ # Diffusers local/remote
1095
+ try:
1096
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1097
+ except EnvironmentError as e:
1098
+ print(f"exception occurs in loading vae: {e}")
1099
+ print("retry with subfolder='vae'")
1100
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1101
+ return vae
1102
+
1103
+ # local
1104
+ vae_config = create_vae_diffusers_config()
1105
+
1106
+ if vae_id.endswith(".bin"):
1107
+ # SD 1.5 VAE on Huggingface
1108
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1109
+ else:
1110
+ # StableDiffusion
1111
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1112
+ else torch.load(vae_id, map_location="cpu"))
1113
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1114
+
1115
+ # vae only or full model
1116
+ full_model = False
1117
+ for vae_key in vae_sd:
1118
+ if vae_key.startswith(VAE_PREFIX):
1119
+ full_model = True
1120
+ break
1121
+ if not full_model:
1122
+ sd = {}
1123
+ for key, value in vae_sd.items():
1124
+ sd[VAE_PREFIX + key] = value
1125
+ vae_sd = sd
1126
+ del sd
1127
+
1128
+ # Convert the VAE model.
1129
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1130
+
1131
+ vae = AutoencoderKL(**vae_config)
1132
+ vae.load_state_dict(converted_vae_checkpoint)
1133
+ return vae
1134
+
1135
+ # endregion
1136
+
1137
+
1138
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1139
+ max_width, max_height = max_reso
1140
+ max_area = (max_width // divisible) * (max_height // divisible)
1141
+
1142
+ resos = set()
1143
+
1144
+ size = int(math.sqrt(max_area)) * divisible
1145
+ resos.add((size, size))
1146
+
1147
+ size = min_size
1148
+ while size <= max_size:
1149
+ width = size
1150
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1151
+ resos.add((width, height))
1152
+ resos.add((height, width))
1153
+
1154
+ # # make additional resos
1155
+ # if width >= height and width - divisible >= min_size:
1156
+ # resos.add((width - divisible, height))
1157
+ # resos.add((height, width - divisible))
1158
+ # if height >= width and height - divisible >= min_size:
1159
+ # resos.add((width, height - divisible))
1160
+ # resos.add((height - divisible, width))
1161
+
1162
+ size += divisible
1163
+
1164
+ resos = list(resos)
1165
+ resos.sort()
1166
+ return resos
1167
+
1168
+
1169
+ if __name__ == '__main__':
1170
+ resos = make_bucket_resolutions((512, 768))
1171
+ print(len(resos))
1172
+ print(resos)
1173
+ aspect_ratios = [w / h for w, h in resos]
1174
+ print(aspect_ratios)
1175
+
1176
+ ars = set()
1177
+ for ar in aspect_ratios:
1178
+ if ar in ars:
1179
+ print("error! duplicate ar:", ar)
1180
+ ars.add(ar)
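A minimal usage sketch of the load/save helpers in this module (illustrative only, not part of the uploaded file; the file names are hypothetical):

    # convert an SD v1 checkpoint to safetensors with the functions defined above
    text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(False, "model.ckpt")
    save_stable_diffusion_checkpoint(False, "model-converted.safetensors", text_encoder, unet,
                                     "model.ckpt", epochs=0, steps=0,
                                     save_dtype=torch.float16, vae=vae)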
build/lib/library/train_util.py ADDED
@@ -0,0 +1,1796 @@
1
+ # common functions for training
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ import time
7
+ from typing import Dict, List, NamedTuple, Tuple
8
+ from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
+ import glob
11
+ import math
12
+ import os
13
+ import random
14
+ import hashlib
15
+ import subprocess
16
+ from io import BytesIO
17
+
18
+ from tqdm import tqdm
19
+ import torch
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer
22
+ import diffusers
23
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
24
+ import albumentations as albu
25
+ import numpy as np
26
+ from PIL import Image
27
+ import cv2
28
+ from einops import rearrange
29
+ from torch import einsum
30
+ import safetensors.torch
31
+
32
+ import library.model_util as model_util
33
+
34
+ # Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
35
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
36
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
37
+
38
+ # checkpoint file names
39
+ EPOCH_STATE_NAME = "{}-{:06d}-state"
40
+ EPOCH_FILE_NAME = "{}-{:06d}"
41
+ EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
42
+ LAST_STATE_NAME = "{}-state"
43
+ DEFAULT_EPOCH_NAME = "epoch"
44
+ DEFAULT_LAST_OUTPUT_NAME = "last"
45
+
46
+ # region dataset
47
+
48
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
49
+ # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
50
+
51
+
52
+ class ImageInfo():
53
+ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
54
+ self.image_key: str = image_key
55
+ self.num_repeats: int = num_repeats
56
+ self.caption: str = caption
57
+ self.is_reg: bool = is_reg
58
+ self.absolute_path: str = absolute_path
59
+ self.image_size: Tuple[int, int] = None
60
+ self.resized_size: Tuple[int, int] = None
61
+ self.bucket_reso: Tuple[int, int] = None
62
+ self.latents: torch.Tensor = None
63
+ self.latents_flipped: torch.Tensor = None
64
+ self.latents_npz: str = None
65
+ self.latents_npz_flipped: str = None
66
+
67
+
68
+ class BucketManager():
69
+ def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
70
+ self.no_upscale = no_upscale
71
+ if max_reso is None:
72
+ self.max_reso = None
73
+ self.max_area = None
74
+ else:
75
+ self.max_reso = max_reso
76
+ self.max_area = max_reso[0] * max_reso[1]
77
+ self.min_size = min_size
78
+ self.max_size = max_size
79
+ self.reso_steps = reso_steps
80
+
81
+ self.resos = []
82
+ self.reso_to_id = {}
83
+ self.buckets = [] # holds (image_key, image) during preprocessing, image_key during training
84
+
85
+ def add_image(self, reso, image):
86
+ bucket_id = self.reso_to_id[reso]
87
+ self.buckets[bucket_id].append(image)
88
+
89
+ def shuffle(self):
90
+ for bucket in self.buckets:
91
+ random.shuffle(bucket)
92
+
93
+ def sort(self):
94
+ # sort by resolution (only to make the display and stored metadata look nicer); also reorder buckets and reassign reso_to_id
95
+ sorted_resos = self.resos.copy()
96
+ sorted_resos.sort()
97
+
98
+ sorted_buckets = []
99
+ sorted_reso_to_id = {}
100
+ for i, reso in enumerate(sorted_resos):
101
+ bucket_id = self.reso_to_id[reso]
102
+ sorted_buckets.append(self.buckets[bucket_id])
103
+ sorted_reso_to_id[reso] = i
104
+
105
+ self.resos = sorted_resos
106
+ self.buckets = sorted_buckets
107
+ self.reso_to_id = sorted_reso_to_id
108
+
109
+ def make_buckets(self):
110
+ resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
111
+ self.set_predefined_resos(resos)
112
+
113
+ def set_predefined_resos(self, resos):
114
+ # store the resolutions and aspect ratios used when selecting from the predefined sizes
115
+ self.predefined_resos = resos.copy()
116
+ self.predefined_resos_set = set(resos)
117
+ self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
118
+
119
+ def add_if_new_reso(self, reso):
120
+ if reso not in self.reso_to_id:
121
+ bucket_id = len(self.resos)
122
+ self.reso_to_id[reso] = bucket_id
123
+ self.resos.append(reso)
124
+ self.buckets.append([])
125
+ # print(reso, bucket_id, len(self.buckets))
126
+
127
+ def round_to_steps(self, x):
128
+ x = int(x + .5)
129
+ return x - x % self.reso_steps
130
+
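For illustration only (not part of the uploaded file): round_to_steps snaps a value down to the nearest multiple of reso_steps after rounding to an integer.

    # illustrative sketch using the BucketManager defined above
    bm = BucketManager(no_upscale=False, max_reso=(512, 768), min_size=256, max_size=1024, reso_steps=64)
    print(bm.round_to_steps(300))    # 256
    print(bm.round_to_steps(511.7))  # 512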
131
+ def select_bucket(self, image_width, image_height):
132
+ aspect_ratio = image_width / image_height
133
+ if not self.no_upscale:
134
+ # the same aspect ratio may already exist (when preprocessing with no_upscale=True for fine tuning), so prefer an exact resolution match
135
+ reso = (image_width, image_height)
136
+ if reso in self.predefined_resos_set:
137
+ pass
138
+ else:
139
+ ar_errors = self.predefined_aspect_ratios - aspect_ratio
140
+ predefined_bucket_id = np.abs(ar_errors).argmin() # the predefined resolution with the smallest aspect ratio error
141
+ reso = self.predefined_resos[predefined_bucket_id]
142
+
143
+ ar_reso = reso[0] / reso[1]
144
+ if aspect_ratio > ar_reso: # image is wider -> fit the height
145
+ scale = reso[1] / image_height
146
+ else:
147
+ scale = reso[0] / image_width
148
+
149
+ resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
150
+ # print("use predef", image_width, image_height, reso, resized_size)
151
+ else:
152
+ if image_width * image_height > self.max_area:
153
+ # the image is too large, so pick a bucket assuming it will be shrunk while preserving the aspect ratio
154
+ resized_width = math.sqrt(self.max_area * aspect_ratio)
155
+ resized_height = self.max_area / resized_width
156
+ assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
157
+
158
+ # round the resized width or height to a multiple of reso_steps, choosing whichever gives the smaller aspect ratio error
159
+ # same logic as the original bucketing
160
+ b_width_rounded = self.round_to_steps(resized_width)
161
+ b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
162
+ ar_width_rounded = b_width_rounded / b_height_in_wr
163
+
164
+ b_height_rounded = self.round_to_steps(resized_height)
165
+ b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
166
+ ar_height_rounded = b_width_in_hr / b_height_rounded
167
+
168
+ # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
169
+ # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
170
+
171
+ if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
172
+ resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
173
+ else:
174
+ resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
175
+ # print(resized_size)
176
+ else:
177
+ resized_size = (image_width, image_height) # no resize needed
178
+
179
+ # use a bucket size no larger than the image (crop rather than pad)
180
+ bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
181
+ bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
182
+ # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
183
+
184
+ reso = (bucket_width, bucket_height)
185
+
186
+ self.add_if_new_reso(reso)
187
+
188
+ ar_error = (reso[0] / reso[1]) - aspect_ratio
189
+ return reso, resized_size, ar_error
190
+
191
+
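For illustration only (not part of the uploaded file): a worked example of the no_upscale branch above, with max_reso=(512, 768) (max_area 393216) and reso_steps=64.

    # a 2000x1000 image (aspect ratio 2.0) exceeds max_area, so it is shrunk to roughly 887x443,
    # both rounding candidates are compared, and the height-based rounding wins (exact 2.0 ratio)
    bm = BucketManager(no_upscale=True, max_reso=(512, 768), min_size=None, max_size=None, reso_steps=64)
    reso, resized_size, ar_error = bm.select_bucket(2000, 1000)
    # reso == (768, 384), resized_size == (768, 384), ar_error == 0.0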
192
+ class BucketBatchIndex(NamedTuple):
193
+ bucket_index: int
194
+ bucket_batch_size: int
195
+ batch_index: int
196
+
197
+
198
+ class BaseDataset(torch.utils.data.Dataset):
199
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
+ super().__init__()
201
+ self.tokenizer: CLIPTokenizer = tokenizer
202
+ self.max_token_length = max_token_length
203
+ self.shuffle_caption = shuffle_caption
204
+ self.shuffle_keep_tokens = shuffle_keep_tokens
205
+ # width/height is used when enable_bucket==False
206
+ self.width, self.height = (None, None) if resolution is None else resolution
207
+ self.face_crop_aug_range = face_crop_aug_range
208
+ self.flip_aug = flip_aug
209
+ self.color_aug = color_aug
210
+ self.debug_dataset = debug_dataset
211
+ self.random_crop = random_crop
212
+ self.token_padding_disabled = False
213
+ self.dataset_dirs_info = {}
214
+ self.reg_dataset_dirs_info = {}
215
+ self.tag_frequency = {}
216
+
217
+ self.enable_bucket = False
218
+ self.bucket_manager: BucketManager = None # not initialized
219
+ self.min_bucket_reso = None
220
+ self.max_bucket_reso = None
221
+ self.bucket_reso_steps = None
222
+ self.bucket_no_upscale = None
223
+ self.bucket_info = None # for metadata
224
+
225
+ self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
+
227
+ self.current_epoch: int = 0 # instances seem to be recreated every epoch, so this must be set from outside
228
+ self.dropout_rate: float = 0
229
+ self.dropout_every_n_epochs: int = None
230
+ self.tag_dropout_rate: float = 0
231
+
232
+ # augmentation
233
+ flip_p = 0.5 if flip_aug else 0.0
234
+ if color_aug:
235
+ # fairly mild color augmentation: brightness/contrast would change the image's min/max pixel values, which is assumed to be undesirable, so only gamma/hue are adjusted
236
+ self.aug = albu.Compose([
237
+ albu.OneOf([
238
+ albu.HueSaturationValue(8, 0, 0, p=.5),
239
+ albu.RandomGamma((95, 105), p=.5),
240
+ ], p=.33),
241
+ albu.HorizontalFlip(p=flip_p)
242
+ ], p=1.)
243
+ elif flip_aug:
244
+ self.aug = albu.Compose([
245
+ albu.HorizontalFlip(p=flip_p)
246
+ ], p=1.)
247
+ else:
248
+ self.aug = None
249
+
250
+ self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
+
252
+ self.image_data: Dict[str, ImageInfo] = {}
253
+
254
+ self.replacements = {}
255
+
256
+ def set_current_epoch(self, epoch):
257
+ self.current_epoch = epoch
258
+
259
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
+ # not passed via the constructor so that Textual Inversion does not have to care about it (or so the story goes)
261
+ self.dropout_rate = dropout_rate
262
+ self.dropout_every_n_epochs = dropout_every_n_epochs
263
+ self.tag_dropout_rate = tag_dropout_rate
264
+
265
+ def set_tag_frequency(self, dir_name, captions):
266
+ frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
+ self.tag_frequency[dir_name] = frequency_for_dir
268
+ for caption in captions:
269
+ for tag in caption.split(","):
270
+ if tag and not tag.isspace():
271
+ tag = tag.lower()
272
+ frequency = frequency_for_dir.get(tag, 0)
273
+ frequency_for_dir[tag] = frequency + 1
274
+
275
+ def disable_token_padding(self):
276
+ self.token_padding_disabled = True
277
+
278
+ def add_replacement(self, str_from, str_to):
279
+ self.replacements[str_from] = str_to
280
+
281
+ def process_caption(self, caption):
282
+ # decide dropout here: tag dropout is also handled in this method, so this is the natural place
283
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
+
286
+ if is_drop_out:
287
+ caption = ""
288
+ else:
289
+ if self.shuffle_caption or self.tag_dropout_rate > 0:
290
+ def dropout_tags(tokens):
291
+ if self.tag_dropout_rate <= 0:
292
+ return tokens
293
+ l = []
294
+ for token in tokens:
295
+ if random.random() >= self.tag_dropout_rate:
296
+ l.append(token)
297
+ return l
298
+
299
+ tokens = [t.strip() for t in caption.strip().split(",")]
300
+ if self.shuffle_keep_tokens is None:
301
+ if self.shuffle_caption:
302
+ random.shuffle(tokens)
303
+
304
+ tokens = dropout_tags(tokens)
305
+ else:
306
+ if len(tokens) > self.shuffle_keep_tokens:
307
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
308
+ tokens = tokens[self.shuffle_keep_tokens:]
309
+
310
+ if self.shuffle_caption:
311
+ random.shuffle(tokens)
312
+
313
+ tokens = dropout_tags(tokens)
314
+
315
+ tokens = keep_tokens + tokens
316
+ caption = ", ".join(tokens)
317
+
318
+ # textual inversion support
319
+ for str_from, str_to in self.replacements.items():
320
+ if str_from == "":
321
+ # replace all
322
+ if type(str_to) == list:
323
+ caption = random.choice(str_to)
324
+ else:
325
+ caption = str_to
326
+ else:
327
+ caption = caption.replace(str_from, str_to)
328
+
329
+ return caption
330
+
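The caption handling above mixes whole-caption dropout, tag shuffling, and per-tag dropout. As a rough standalone illustration of the tag-level part (not part of this file; the function name and default values are made up):

import random

def shuffle_and_drop_tags(caption: str, keep_tokens: int = 1, tag_dropout_rate: float = 0.1) -> str:
    # split the comma-separated caption into tags, keeping the first `keep_tokens` fixed
    tags = [t.strip() for t in caption.strip().split(",")]
    head, tail = tags[:keep_tokens], tags[keep_tokens:]
    random.shuffle(tail)  # shuffle only the non-kept tags
    tail = [t for t in tail if random.random() >= tag_dropout_rate]  # drop each tag independently
    return ", ".join(head + tail)

# e.g. shuffle_and_drop_tags("1girl, smile, blue hair, outdoors", keep_tokens=1)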
331
+ def get_input_ids(self, caption):
332
+ input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
333
+ max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
334
+
335
+ if self.tokenizer_max_length > self.tokenizer.model_max_length:
336
+ input_ids = input_ids.squeeze(0)
337
+ iids_list = []
338
+ if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
339
+ # v1
340
+ # when longer than 77 tokens the ids look like "<BOS> .... <EOS> <EOS> <EOS>" with a total length of e.g. 227, so convert them into three "<BOS>...<EOS>" sequences
341
+ # the AUTOMATIC1111 implementation seems to split on "," and so on, but keep it simple here for now
342
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
343
+ ids_chunk = (input_ids[0].unsqueeze(0),
344
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
345
+ input_ids[-1].unsqueeze(0))
346
+ ids_chunk = torch.cat(ids_chunk)
347
+ iids_list.append(ids_chunk)
348
+ else:
349
+ # v2
350
+ # when longer than 77 tokens the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." with a total length of e.g. 227, so convert them into three "<BOS>...<EOS> <PAD> <PAD> ..." sequences
351
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
352
+ ids_chunk = (input_ids[0].unsqueeze(0), # BOS
353
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
354
+ input_ids[-1].unsqueeze(0)) # PAD or EOS
355
+ ids_chunk = torch.cat(ids_chunk)
356
+
357
+ # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
358
+ # if the chunk ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it was already x <EOS>)
359
+ if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
360
+ ids_chunk[-1] = self.tokenizer.eos_token_id
361
+ # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
362
+ if ids_chunk[1] == self.tokenizer.pad_token_id:
363
+ ids_chunk[1] = self.tokenizer.eos_token_id
364
+
365
+ iids_list.append(ids_chunk)
366
+
367
+ input_ids = torch.stack(iids_list) # 3,77
368
+ return input_ids
369
+
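A quick sanity check of the chunking arithmetic above, assuming max_token_length=225 (so the stored tokenizer_max_length is 227) and a CLIP model_max_length of 77; the numbers are a worked example only, not extra functionality:

model_max = 77   # CLIP tokenizer.model_max_length
tok_max = 227    # max_token_length (225) + 2
starts = list(range(1, tok_max - model_max + 2, model_max - 2))
assert starts == [1, 76, 151]  # three windows of 75 content tokens each
# each window gets the original <BOS> and the final token re-attached,
# so the result is three 77-token sequences stacked into a (3, 77) tensor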
370
+ def register_image(self, info: ImageInfo):
371
+ self.image_data[info.image_key] = info
372
+
373
+ def make_buckets(self):
374
+ '''
375
+ must be called even when bucketing is disabled (a single bucket is created)
376
+ min_size and max_size are ignored when enable_bucket is False
377
+ '''
378
+ print("loading image sizes.")
379
+ for info in tqdm(self.image_data.values()):
380
+ if info.image_size is None:
381
+ info.image_size = self.get_image_size(info.absolute_path)
382
+
383
+ if self.enable_bucket:
384
+ print("make buckets")
385
+ else:
386
+ print("prepare dataset")
387
+
388
+ # create the buckets and assign each image to one
389
+ if self.enable_bucket:
390
+ if self.bucket_manager is None: # already initialized when fine tuning and the metadata defines the buckets
391
+ self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
392
+ self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
393
+ if not self.bucket_no_upscale:
394
+ self.bucket_manager.make_buckets()
395
+ else:
396
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
397
+
398
+ img_ar_errors = []
399
+ for image_info in self.image_data.values():
400
+ image_width, image_height = image_info.image_size
401
+ image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
402
+
403
+ # print(image_info.image_key, image_info.bucket_reso)
404
+ img_ar_errors.append(abs(ar_error))
405
+
406
+ self.bucket_manager.sort()
407
+ else:
408
+ self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
409
+ self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # a single fixed-size bucket only
410
+ for image_info in self.image_data.values():
411
+ image_width, image_height = image_info.image_size
412
+ image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
413
+
414
+ for image_info in self.image_data.values():
415
+ for _ in range(image_info.num_repeats):
416
+ self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
417
+
418
+ # print and store bucket info
419
+ if self.enable_bucket:
420
+ self.bucket_info = {"buckets": {}}
421
+ print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
422
+ for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
423
+ count = len(bucket)
424
+ if count > 0:
425
+ self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
426
+ print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
427
+
428
+ img_ar_errors = np.array(img_ar_errors)
429
+ mean_img_ar_error = np.mean(np.abs(img_ar_errors))
430
+ self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
431
+ print(f"mean ar error (without repeats): {mean_img_ar_error}")
432
+
433
+ # build the index used to reference the data; this index is what gets shuffled for the dataset
434
+ self.buckets_indices: List[BucketBatchIndex] = []
435
+ for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
436
+ batch_count = int(math.ceil(len(bucket) / self.batch_size))
437
+ for batch_index in range(batch_count):
438
+ self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
439
+
440
+ # ↓ the code below was reverted because the number of batches per bucket grew so much that it became confusing
441
+ #  since steps are drawn randomly during training, having the same image twice in one batch is assumed to do little harm
442
+ #
443
+ # # as buckets become more fine-grained, more buckets end up containing only one kind of image, which means
444
+ # # a whole batch would be filled with the same image, which is clearly not good
445
+ # # therefore limit the batch size to the number of distinct images in the bucket
446
+ # # even so, the same image can still end up in the same batch, so fewer repeats should give better shuffle quality(?)
447
+ # # TODO a mechanism to use regularization images across epochs
448
+ # num_of_image_types = len(set(bucket))
449
+ # bucket_batch_size = min(self.batch_size, num_of_image_types)
450
+ # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
451
+ # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
452
+ # for batch_index in range(batch_count):
453
+ # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
454
+ # ↑ end of the reverted block
455
+
456
+ self.shuffle_buckets()
457
+ self._length = len(self.buckets_indices)
458
+
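For orientation, a toy calculation of what make_buckets produces (the numbers are hypothetical): the dataset length is the total number of batches across buckets, and each __getitem__ call returns one whole batch.

import math
batch_size = 4
bucket_sizes = [10, 3, 7]  # images per bucket, repeats included
steps = sum(math.ceil(n / batch_size) for n in bucket_sizes)
assert steps == 6  # ceil(10/4) + ceil(3/4) + ceil(7/4) entries in buckets_indices, i.e. len(dataset)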
459
+ def shuffle_buckets(self):
460
+ random.shuffle(self.buckets_indices)
461
+ self.bucket_manager.shuffle()
462
+
463
+ def load_image(self, image_path):
464
+ image = Image.open(image_path)
465
+ if not image.mode == "RGB":
466
+ image = image.convert("RGB")
467
+ img = np.array(image, np.uint8)
468
+ return img
469
+
470
+ def trim_and_resize_if_required(self, image, reso, resized_size):
471
+ image_height, image_width = image.shape[0:2]
472
+
473
+ if image_width != resized_size[0] or image_height != resized_size[1]:
474
+ # resize
475
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # resize with cv2 because we want INTER_AREA
476
+
477
+ image_height, image_width = image.shape[0:2]
478
+ if image_width > reso[0]:
479
+ trim_size = image_width - reso[0]
480
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
+ # print("w", trim_size, p)
482
+ image = image[:, p:p + reso[0]]
483
+ if image_height > reso[1]:
484
+ trim_size = image_height - reso[1]
485
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
+ # print("h", trim_size, p)
487
+ image = image[p:p + reso[1]]
488
+
489
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
+ return image
491
+
492
+ def cache_latents(self, vae):
493
+ # TODO speed this up
494
+ print("caching latents.")
495
+ for info in tqdm(self.image_data.values()):
496
+ if info.latents_npz is not None:
497
+ info.latents = self.load_latents_from_npz(info, False)
498
+ info.latents = torch.FloatTensor(info.latents)
499
+ info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
500
+ if info.latents_flipped is not None:
501
+ info.latents_flipped = torch.FloatTensor(info.latents_flipped)
502
+ continue
503
+
504
+ image = self.load_image(info.absolute_path)
505
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
+
507
+ img_tensor = self.image_transforms(image)
508
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
+ info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
+
511
+ if self.flip_aug:
512
+ image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
+ img_tensor = self.image_transforms(image)
514
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
515
+ info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
516
+
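As a reminder of what gets cached above (assuming the standard Stable Diffusion VAE with an 8x spatial downsampling factor and 4 latent channels): a 512x512 crop is stored as a (4, 64, 64) tensor per image, plus a flipped copy when flip_aug is enabled.

bucket_reso = (512, 512)                                   # hypothetical bucket resolution (width, height)
latent_shape = (4, bucket_reso[1] // 8, bucket_reso[0] // 8)
assert latent_shape == (4, 64, 64)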
517
+ def get_image_size(self, image_path):
518
+ image = Image.open(image_path)
519
+ return image.size
520
+
521
+ def load_image_with_face_info(self, image_path: str):
522
+ img = self.load_image(image_path)
523
+
524
+ face_cx = face_cy = face_w = face_h = 0
525
+ if self.face_crop_aug_range is not None:
526
+ tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
+ if len(tokens) >= 5:
528
+ face_cx = int(tokens[-4])
529
+ face_cy = int(tokens[-3])
530
+ face_w = int(tokens[-2])
531
+ face_h = int(tokens[-1])
532
+
533
+ return img, face_cx, face_cy, face_w, face_h
534
+
535
+ # crop a nice region around the face
536
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
+ height, width = image.shape[0:2]
538
+ if height == self.height and width == self.width:
539
+ return image
540
+
541
+ # the image is larger than the target size, so resize it
542
+ face_size = max(face_w, face_h)
543
+ min_scale = max(self.height / height, self.width / width) # scale at which the image exactly fits the model input size (i.e. the minimum scale)
544
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # scale for the specified minimum face size
545
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # scale for the specified maximum face size
546
+ if min_scale >= max_scale: # the range was specified as min == max
547
+ scale = min_scale
548
+ else:
549
+ scale = random.uniform(min_scale, max_scale)
550
+
551
+ nh = int(height * scale + .5)
552
+ nw = int(width * scale + .5)
553
+ assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
554
+ image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
555
+ face_cx = int(face_cx * scale + .5)
556
+ face_cy = int(face_cy * scale + .5)
557
+ height, width = nh, nw
558
+
559
+ # crop to e.g. 448*640 with the face at the center
560
+ for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
+ p1 = face_p - target_size // 2 # crop position that brings the face to the center
562
+
563
+ if self.random_crop:
564
+ # shift the crop so that some background is included, while keeping a high probability of the face staying near the center
565
+ range = max(length - face_p, face_p) # the longer of the distances from the image edges to the face center
566
+ p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a roughly triangular random offset in [-range, +range]
567
+ else:
568
+ # only when a range is specified, add a small random jitter (fairly ad hoc)
569
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
+ if face_size > self.size // 10 and face_size >= 40:
571
+ p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
+
573
+ p1 = max(0, min(p1, length - target_size))
574
+
575
+ if axis == 0:
576
+ image = image[p1:p1 + target_size, :]
577
+ else:
578
+ image = image[:, p1:p1 + target_size]
579
+
580
+ return image
581
+
582
+ def load_latents_from_npz(self, image_info: ImageInfo, flipped):
583
+ npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
584
+ if npz_file is None:
585
+ return None
586
+ return np.load(npz_file)['arr_0']
587
+
588
+ def __len__(self):
589
+ return self._length
590
+
591
+ def __getitem__(self, index):
592
+ if index == 0:
593
+ self.shuffle_buckets()
594
+
595
+ bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
+ bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
+ image_index = self.buckets_indices[index].batch_index * bucket_batch_size
598
+
599
+ loss_weights = []
600
+ captions = []
601
+ input_ids_list = []
602
+ latents_list = []
603
+ images = []
604
+
605
+ for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
+ image_info = self.image_data[image_key]
607
+ loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
+
609
+ # handle image/latents
610
+ if image_info.latents is not None:
611
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
+ image = None
613
+ elif image_info.latents_npz is not None:
614
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
+ latents = torch.FloatTensor(latents)
616
+ image = None
617
+ else:
618
+ # load the image and crop it if necessary
619
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
+ im_h, im_w = img.shape[0:2]
621
+
622
+ if self.enable_bucket:
623
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
+ else:
625
+ if face_cx > 0: # face position info is available
626
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
+ elif im_h > self.height or im_w > self.width:
628
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
+ if im_h > self.height:
630
+ p = random.randint(0, im_h - self.height)
631
+ img = img[p:p + self.height]
632
+ if im_w > self.width:
633
+ p = random.randint(0, im_w - self.width)
634
+ img = img[:, p:p + self.width]
635
+
636
+ im_h, im_w = img.shape[0:2]
637
+ assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
+
639
+ # augmentation
640
+ if self.aug is not None:
641
+ img = self.aug(image=img)['image']
642
+
643
+ latents = None
644
+ image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0
645
+
646
+ images.append(image)
647
+ latents_list.append(latents)
648
+
649
+ caption = self.process_caption(image_info.caption)
650
+ captions.append(caption)
651
+ if not self.token_padding_disabled: # this option might be omitted in future
652
+ input_ids_list.append(self.get_input_ids(caption))
653
+
654
+ example = {}
655
+ example['loss_weights'] = torch.FloatTensor(loss_weights)
656
+
657
+ if self.token_padding_disabled:
658
+ # padding=True means pad in the batch
659
+ example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
660
+ else:
661
+ # batch processing seems to be good
662
+ example['input_ids'] = torch.stack(input_ids_list)
663
+
664
+ if images[0] is not None:
665
+ images = torch.stack(images)
666
+ images = images.to(memory_format=torch.contiguous_format).float()
667
+ else:
668
+ images = None
669
+ example['images'] = images
670
+
671
+ example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
672
+
673
+ if self.debug_dataset:
674
+ example['image_keys'] = bucket[image_index:image_index + self.batch_size]
675
+ example['captions'] = captions
676
+ return example
677
+
678
+
679
+ class DreamBoothDataset(BaseDataset):
680
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
+
684
+ assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
+
686
+ self.batch_size = batch_size
687
+ self.size = min(self.width, self.height) # the shorter side
688
+ self.prior_loss_weight = prior_loss_weight
689
+ self.latents_cache = None
690
+
691
+ self.enable_bucket = enable_bucket
692
+ if self.enable_bucket:
693
+ assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
694
+ assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
695
+ self.min_bucket_reso = min_bucket_reso
696
+ self.max_bucket_reso = max_bucket_reso
697
+ self.bucket_reso_steps = bucket_reso_steps
698
+ self.bucket_no_upscale = bucket_no_upscale
699
+ else:
700
+ self.min_bucket_reso = None
701
+ self.max_bucket_reso = None
702
+ self.bucket_reso_steps = None # this value is not used
703
+ self.bucket_no_upscale = False
704
+
705
+ def read_caption(img_path):
706
+ # build candidate caption file names
707
+ base_name = os.path.splitext(img_path)[0]
708
+ base_name_face_det = base_name
709
+ tokens = base_name.split("_")
710
+ if len(tokens) >= 5:
711
+ base_name_face_det = "_".join(tokens[:-4])
712
+ cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
713
+
714
+ caption = None
715
+ for cap_path in cap_paths:
716
+ if os.path.isfile(cap_path):
717
+ with open(cap_path, "rt", encoding='utf-8') as f:
718
+ try:
719
+ lines = f.readlines()
720
+ except UnicodeDecodeError as e:
721
+ print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
722
+ raise e
723
+ assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
724
+ caption = lines[0].strip()
725
+ break
726
+ return caption
727
+
728
+ def load_dreambooth_dir(dir):
729
+ if not os.path.isdir(dir):
730
+ # print(f"ignore file: {dir}")
731
+ return 0, [], []
732
+
733
+ tokens = os.path.basename(dir).split('_')
734
+ try:
735
+ n_repeats = int(tokens[0])
736
+ except ValueError as e:
737
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
+ return 0, [], []
739
+
740
+ caption_by_folder = '_'.join(tokens[1:])
741
+ img_paths = glob_images(dir, "*")
742
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
+
744
+ # read the per-image caption and prefer it if present
745
+ captions = []
746
+ for img_path in img_paths:
747
+ cap_for_img = read_caption(img_path)
748
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
749
+
750
+ self.set_tag_frequency(os.path.basename(dir), captions) # record tag frequencies
751
+
752
+ return n_repeats, img_paths, captions
753
+
754
+ print("prepare train images.")
755
+ train_dirs = os.listdir(train_data_dir)
756
+ num_train_images = 0
757
+ for dir in train_dirs:
758
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
+ num_train_images += n_repeats * len(img_paths)
760
+
761
+ for img_path, caption in zip(img_paths, captions):
762
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
+ self.register_image(info)
764
+
765
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
766
+
767
+ print(f"{num_train_images} train images with repeating.")
768
+ self.num_train_images = num_train_images
769
+
770
+ # count the regularization images and match them to the number of training images
771
+ num_reg_images = 0
772
+ if reg_data_dir:
773
+ print("prepare reg images.")
774
+ reg_infos: List[ImageInfo] = []
775
+
776
+ reg_dirs = os.listdir(reg_data_dir)
777
+ for dir in reg_dirs:
778
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
+ num_reg_images += n_repeats * len(img_paths)
780
+
781
+ for img_path, caption in zip(img_paths, captions):
782
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
+ reg_infos.append(info)
784
+
785
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
+
787
+ print(f"{num_reg_images} reg images.")
788
+ if num_train_images < num_reg_images:
789
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
+
791
+ if num_reg_images == 0:
792
+ print("no regularization images / 正則化画像が見つかりませんでした")
793
+ else:
794
+ # compute num_repeats: the count is small anyway, so just loop
795
+ n = 0
796
+ first_loop = True
797
+ while n < num_train_images:
798
+ for info in reg_infos:
799
+ if first_loop:
800
+ self.register_image(info)
801
+ n += info.num_repeats
802
+ else:
803
+ info.num_repeats += 1
804
+ n += 1
805
+ if n >= num_train_images:
806
+ break
807
+ first_loop = False
808
+
809
+ self.num_reg_images = num_reg_images
810
+
811
+
812
+ class FineTuningDataset(BaseDataset):
813
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
+
817
+ # load the metadata
818
+ if os.path.exists(json_file_name):
819
+ print(f"loading existing metadata: {json_file_name}")
820
+ with open(json_file_name, "rt", encoding='utf-8') as f:
821
+ metadata = json.load(f)
822
+ else:
823
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
+
825
+ self.metadata = metadata
826
+ self.train_data_dir = train_data_dir
827
+ self.batch_size = batch_size
828
+
829
+ tags_list = []
830
+ for image_key, img_md in metadata.items():
831
+ # build the path info
832
+ if os.path.exists(image_key):
833
+ abs_path = image_key
834
+ else:
835
+ # fairly rough, but no better way comes to mind
836
+ abs_path = glob_images(train_data_dir, image_key)
837
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
+ abs_path = abs_path[0]
839
+
840
+ caption = img_md.get('caption')
841
+ tags = img_md.get('tags')
842
+ if caption is None:
843
+ caption = tags
844
+ elif tags is not None and len(tags) > 0:
845
+ caption = caption + ', ' + tags
846
+ tags_list.append(tags)
847
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
+
849
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
+ image_info.image_size = img_md.get('train_resolution')
851
+
852
+ if not self.color_aug and not self.random_crop:
853
+ # if npz exists, use them
854
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
+
856
+ self.register_image(image_info)
857
+ self.num_train_images = len(metadata) * dataset_repeats
858
+ self.num_reg_images = 0
859
+
860
+ # TODO do not record tag freq when no tag
861
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
863
+
864
+ # check existence of all npz files
865
+ use_npz_latents = not (self.color_aug or self.random_crop)
866
+ if use_npz_latents:
867
+ npz_any = False
868
+ npz_all = True
869
+ for image_info in self.image_data.values():
870
+ has_npz = image_info.latents_npz is not None
871
+ npz_any = npz_any or has_npz
872
+
873
+ if self.flip_aug:
874
+ has_npz = has_npz and image_info.latents_npz_flipped is not None
875
+ npz_all = npz_all and has_npz
876
+
877
+ if npz_any and not npz_all:
878
+ break
879
+
880
+ if not npz_any:
881
+ use_npz_latents = False
882
+ print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
883
+ elif not npz_all:
884
+ use_npz_latents = False
885
+ print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
886
+ if self.flip_aug:
887
+ print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
+ # else:
889
+ # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
890
+
891
+ # check min/max bucket size
892
+ sizes = set()
893
+ resos = set()
894
+ for image_info in self.image_data.values():
895
+ if image_info.image_size is None:
896
+ sizes = None # not calculated
897
+ break
898
+ sizes.add(image_info.image_size[0])
899
+ sizes.add(image_info.image_size[1])
900
+ resos.add(tuple(image_info.image_size))
901
+
902
+ if sizes is None:
903
+ if use_npz_latents:
904
+ use_npz_latents = False
905
+ print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
906
+
907
+ assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
908
+
909
+ self.enable_bucket = enable_bucket
910
+ if self.enable_bucket:
911
+ self.min_bucket_reso = min_bucket_reso
912
+ self.max_bucket_reso = max_bucket_reso
913
+ self.bucket_reso_steps = bucket_reso_steps
914
+ self.bucket_no_upscale = bucket_no_upscale
915
+ else:
916
+ if not enable_bucket:
917
+ print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
918
+ print("using bucket info in metadata / メタデータ内のbucket情報を使います")
919
+ self.enable_bucket = True
920
+
921
+ assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
922
+
923
+ # initialize the bucket info here; it is not recreated in make_buckets
924
+ self.bucket_manager = BucketManager(False, None, None, None, None)
925
+ self.bucket_manager.set_predefined_resos(resos)
926
+
927
+ # clean up the npz info
928
+ if not use_npz_latents:
929
+ for image_info in self.image_data.values():
930
+ image_info.latents_npz = image_info.latents_npz_flipped = None
931
+
932
+ def image_key_to_npz_file(self, image_key):
933
+ base_name = os.path.splitext(image_key)[0]
934
+ npz_file_norm = base_name + '.npz'
935
+
936
+ if os.path.exists(npz_file_norm):
937
+ # image_key is full path
938
+ npz_file_flip = base_name + '_flip.npz'
939
+ if not os.path.exists(npz_file_flip):
940
+ npz_file_flip = None
941
+ return npz_file_norm, npz_file_flip
942
+
943
+ # image_key is relative path
944
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
+
947
+ if not os.path.exists(npz_file_norm):
948
+ npz_file_norm = None
949
+ npz_file_flip = None
950
+ elif not os.path.exists(npz_file_flip):
951
+ npz_file_flip = None
952
+
953
+ return npz_file_norm, npz_file_flip
954
+
955
+
956
+ def debug_dataset(train_dataset, show_input_ids=False):
957
+ print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
+ print("Escape for exit. / Escキーで中断、終了します")
959
+
960
+ train_dataset.set_current_epoch(1)
961
+ k = 0
962
+ for i, example in enumerate(train_dataset):
963
+ if example['latents'] is not None:
964
+ print(f"sample has latents from npz file: {example['latents'].size()}")
965
+ for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
966
+ print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
967
+ if show_input_ids:
968
+ print(f"input ids: {iid}")
969
+ if example['images'] is not None:
970
+ im = example['images'][j]
971
+ print(f"image size: {im.size()}")
972
+ im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
973
+ im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
974
+ im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
975
+ if os.name == 'nt': # only windows
976
+ cv2.imshow("img", im)
977
+ k = cv2.waitKey()
978
+ cv2.destroyAllWindows()
979
+ if k == 27:
980
+ break
981
+ if k == 27 or (example['images'] is None and i >= 8):
982
+ break
983
+
984
+
985
+ def glob_images(directory, base="*"):
986
+ img_paths = []
987
+ for ext in IMAGE_EXTENSIONS:
988
+ if base == '*':
989
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
990
+ else:
991
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
992
+ # img_paths = list(set(img_paths)) # remove duplicates
993
+ # img_paths.sort()
994
+ return img_paths
995
+
996
+
997
+ def glob_images_pathlib(dir_path, recursive):
998
+ image_paths = []
999
+ if recursive:
1000
+ for ext in IMAGE_EXTENSIONS:
1001
+ image_paths += list(dir_path.rglob('*' + ext))
1002
+ else:
1003
+ for ext in IMAGE_EXTENSIONS:
1004
+ image_paths += list(dir_path.glob('*' + ext))
1005
+ # image_paths = list(set(image_paths)) # remove duplicates
1006
+ # image_paths.sort()
1007
+ return image_paths
1008
+
1009
+ # endregion
1010
+
1011
+
1012
+ # region module replacement
1013
+ """
1014
+ Module replacement for speed-ups
1015
+ """
1016
+
1017
+ # CrossAttention that uses FlashAttention
1018
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
1019
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
1020
+
1021
+ # constants
1022
+
1023
+ EPSILON = 1e-6
1024
+
1025
+ # helper functions
1026
+
1027
+
1028
+ def exists(val):
1029
+ return val is not None
1030
+
1031
+
1032
+ def default(val, d):
1033
+ return val if exists(val) else d
1034
+
1035
+
1036
+ def model_hash(filename):
1037
+ """Old model hash used by stable-diffusion-webui"""
1038
+ try:
1039
+ with open(filename, "rb") as file:
1040
+ m = hashlib.sha256()
1041
+
1042
+ file.seek(0x100000)
1043
+ m.update(file.read(0x10000))
1044
+ return m.hexdigest()[0:8]
1045
+ except FileNotFoundError:
1046
+ return 'NOFILE'
1047
+
1048
+
1049
+ def calculate_sha256(filename):
1050
+ """New model hash used by stable-diffusion-webui"""
1051
+ hash_sha256 = hashlib.sha256()
1052
+ blksize = 1024 * 1024
1053
+
1054
+ with open(filename, "rb") as f:
1055
+ for chunk in iter(lambda: f.read(blksize), b""):
1056
+ hash_sha256.update(chunk)
1057
+
1058
+ return hash_sha256.hexdigest()
1059
+
1060
+
1061
+ def precalculate_safetensors_hashes(tensors, metadata):
1062
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
1063
+ save time on indexing the model later."""
1064
+
1065
+ # Because writing user metadata to the file can change the result of
1066
+ # sd_models.model_hash(), only retain the training metadata for purposes of
1067
+ # calculating the hash, as they are meant to be immutable
1068
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
1069
+
1070
+ bytes = safetensors.torch.save(tensors, metadata)
1071
+ b = BytesIO(bytes)
1072
+
1073
+ model_hash = addnet_hash_safetensors(b)
1074
+ legacy_hash = addnet_hash_legacy(b)
1075
+ return model_hash, legacy_hash
1076
+
1077
+
1078
+ def addnet_hash_legacy(b):
1079
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
1080
+ m = hashlib.sha256()
1081
+
1082
+ b.seek(0x100000)
1083
+ m.update(b.read(0x10000))
1084
+ return m.hexdigest()[0:8]
1085
+
1086
+
1087
+ def addnet_hash_safetensors(b):
1088
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
1089
+ hash_sha256 = hashlib.sha256()
1090
+ blksize = 1024 * 1024
1091
+
1092
+ b.seek(0)
1093
+ header = b.read(8)
1094
+ n = int.from_bytes(header, "little")
1095
+
1096
+ offset = n + 8
1097
+ b.seek(offset)
1098
+ for chunk in iter(lambda: b.read(blksize), b""):
1099
+ hash_sha256.update(chunk)
1100
+
1101
+ return hash_sha256.hexdigest()
1102
+
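The two addnet hash variants above differ only in what part of the file they cover: the legacy hash samples 0x10000 bytes at offset 0x100000, while the safetensors hash skips the JSON header (whose length is stored as an 8-byte little-endian integer at the start of the file) and hashes the tensor payload. A minimal sketch of the latter for a file on disk (the path argument is hypothetical):

import hashlib

def hash_safetensors_payload(path: str) -> str:
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), "little")  # length of the JSON header
        f.seek(header_len + 8)                            # skip the header, hash only the tensor data
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            sha.update(chunk)
    return sha.hexdigest()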
1103
+
1104
+ def get_git_revision_hash() -> str:
1105
+ try:
1106
+ return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
1107
+ except:
1108
+ return "(unknown)"
1109
+
1110
+
1111
+ # flash attention forwards and backwards
1112
+
1113
+ # https://arxiv.org/abs/2205.14135
1114
+
1115
+
1116
+ class FlashAttentionFunction(torch.autograd.function.Function):
1117
+ @ staticmethod
1118
+ @ torch.no_grad()
1119
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
1120
+ """ Algorithm 2 in the paper """
1121
+
1122
+ device = q.device
1123
+ dtype = q.dtype
1124
+ max_neg_value = -torch.finfo(q.dtype).max
1125
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1126
+
1127
+ o = torch.zeros_like(q)
1128
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
1129
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
1130
+
1131
+ scale = (q.shape[-1] ** -0.5)
1132
+
1133
+ if not exists(mask):
1134
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
1135
+ else:
1136
+ mask = rearrange(mask, 'b n -> b 1 1 n')
1137
+ mask = mask.split(q_bucket_size, dim=-1)
1138
+
1139
+ row_splits = zip(
1140
+ q.split(q_bucket_size, dim=-2),
1141
+ o.split(q_bucket_size, dim=-2),
1142
+ mask,
1143
+ all_row_sums.split(q_bucket_size, dim=-2),
1144
+ all_row_maxes.split(q_bucket_size, dim=-2),
1145
+ )
1146
+
1147
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
1148
+ q_start_index = ind * q_bucket_size - qk_len_diff
1149
+
1150
+ col_splits = zip(
1151
+ k.split(k_bucket_size, dim=-2),
1152
+ v.split(k_bucket_size, dim=-2),
1153
+ )
1154
+
1155
+ for k_ind, (kc, vc) in enumerate(col_splits):
1156
+ k_start_index = k_ind * k_bucket_size
1157
+
1158
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1159
+
1160
+ if exists(row_mask):
1161
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
1162
+
1163
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1164
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1165
+ device=device).triu(q_start_index - k_start_index + 1)
1166
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1167
+
1168
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
1169
+ attn_weights -= block_row_maxes
1170
+ exp_weights = torch.exp(attn_weights)
1171
+
1172
+ if exists(row_mask):
1173
+ exp_weights.masked_fill_(~row_mask, 0.)
1174
+
1175
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
1176
+
1177
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
1178
+
1179
+ exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
1180
+
1181
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
1182
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
1183
+
1184
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
1185
+
1186
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
1187
+
1188
+ row_maxes.copy_(new_row_maxes)
1189
+ row_sums.copy_(new_row_sums)
1190
+
1191
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
1192
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
1193
+
1194
+ return o
1195
+
1196
+ @ staticmethod
1197
+ @ torch.no_grad()
1198
+ def backward(ctx, do):
1199
+ """ Algorithm 4 in the paper """
1200
+
1201
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
1202
+ q, k, v, o, l, m = ctx.saved_tensors
1203
+
1204
+ device = q.device
1205
+
1206
+ max_neg_value = -torch.finfo(q.dtype).max
1207
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1208
+
1209
+ dq = torch.zeros_like(q)
1210
+ dk = torch.zeros_like(k)
1211
+ dv = torch.zeros_like(v)
1212
+
1213
+ row_splits = zip(
1214
+ q.split(q_bucket_size, dim=-2),
1215
+ o.split(q_bucket_size, dim=-2),
1216
+ do.split(q_bucket_size, dim=-2),
1217
+ mask,
1218
+ l.split(q_bucket_size, dim=-2),
1219
+ m.split(q_bucket_size, dim=-2),
1220
+ dq.split(q_bucket_size, dim=-2)
1221
+ )
1222
+
1223
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
1224
+ q_start_index = ind * q_bucket_size - qk_len_diff
1225
+
1226
+ col_splits = zip(
1227
+ k.split(k_bucket_size, dim=-2),
1228
+ v.split(k_bucket_size, dim=-2),
1229
+ dk.split(k_bucket_size, dim=-2),
1230
+ dv.split(k_bucket_size, dim=-2),
1231
+ )
1232
+
1233
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
1234
+ k_start_index = k_ind * k_bucket_size
1235
+
1236
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1237
+
1238
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1239
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1240
+ device=device).triu(q_start_index - k_start_index + 1)
1241
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1242
+
1243
+ exp_attn_weights = torch.exp(attn_weights - mc)
1244
+
1245
+ if exists(row_mask):
1246
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
1247
+
1248
+ p = exp_attn_weights / lc
1249
+
1250
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
1251
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
1252
+
1253
+ D = (doc * oc).sum(dim=-1, keepdims=True)
1254
+ ds = p * scale * (dp - D)
1255
+
1256
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
1257
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
1258
+
1259
+ dqc.add_(dq_chunk)
1260
+ dkc.add_(dk_chunk)
1261
+ dvc.add_(dv_chunk)
1262
+
1263
+ return dq, dk, dv, None, None, None, None
1264
+
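A small smoke test of the autograd function above (shapes are arbitrary; q/k/v use the (batch, heads, seq_len, head_dim) layout that forward_flash_attn produces). This is only an illustration of how the function is invoked, not part of the training code:

import torch

q = torch.randn(1, 8, 64, 40, requires_grad=True)
k = torch.randn(1, 8, 64, 40, requires_grad=True)
v = torch.randn(1, 8, 64, 40, requires_grad=True)
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)  # mask=None, causal=False
out.sum().backward()         # exercises the custom backward (Algorithm 4)
assert out.shape == q.shape  # (1, 8, 64, 40)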
1265
+
1266
+ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
1267
+ if mem_eff_attn:
1268
+ replace_unet_cross_attn_to_memory_efficient()
1269
+ elif xformers:
1270
+ replace_unet_cross_attn_to_xformers()
1271
+
1272
+
1273
+ def replace_unet_cross_attn_to_memory_efficient():
1274
+ print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
1275
+ flash_func = FlashAttentionFunction
1276
+
1277
+ def forward_flash_attn(self, x, context=None, mask=None):
1278
+ q_bucket_size = 512
1279
+ k_bucket_size = 1024
1280
+
1281
+ h = self.heads
1282
+ q = self.to_q(x)
1283
+
1284
+ context = context if context is not None else x
1285
+ context = context.to(x.dtype)
1286
+
1287
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1288
+ context_k, context_v = self.hypernetwork.forward(x, context)
1289
+ context_k = context_k.to(x.dtype)
1290
+ context_v = context_v.to(x.dtype)
1291
+ else:
1292
+ context_k = context
1293
+ context_v = context
1294
+
1295
+ k = self.to_k(context_k)
1296
+ v = self.to_v(context_v)
1297
+ del context, x
1298
+
1299
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
1300
+
1301
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
1302
+
1303
+ out = rearrange(out, 'b h n d -> b n (h d)')
1304
+
1305
+ # diffusers 0.7.0+ (why did they have to change this... (;´Д`))
1306
+ out = self.to_out[0](out)
1307
+ out = self.to_out[1](out)
1308
+ return out
1309
+
1310
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
1311
+
1312
+
1313
+ def replace_unet_cross_attn_to_xformers():
1314
+ print("Replace CrossAttention.forward to use xformers")
1315
+ try:
1316
+ import xformers.ops
1317
+ except ImportError:
1318
+ raise ImportError("No xformers / xformersがインストールされていないようです")
1319
+
1320
+ def forward_xformers(self, x, context=None, mask=None):
1321
+ h = self.heads
1322
+ q_in = self.to_q(x)
1323
+
1324
+ context = default(context, x)
1325
+ context = context.to(x.dtype)
1326
+
1327
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1328
+ context_k, context_v = self.hypernetwork.forward(x, context)
1329
+ context_k = context_k.to(x.dtype)
1330
+ context_v = context_v.to(x.dtype)
1331
+ else:
1332
+ context_k = context
1333
+ context_v = context
1334
+
1335
+ k_in = self.to_k(context_k)
1336
+ v_in = self.to_v(context_v)
1337
+
1338
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
1339
+ del q_in, k_in, v_in
1340
+
1341
+ q = q.contiguous()
1342
+ k = k.contiguous()
1343
+ v = v.contiguous()
1344
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # picks the best kernel automatically
1345
+
1346
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
1347
+
1348
+ # diffusers 0.7.0~
1349
+ out = self.to_out[0](out)
1350
+ out = self.to_out[1](out)
1351
+ return out
1352
+
1353
+ diffusers.models.attention.CrossAttention.forward = forward_xformers
1354
+ # endregion
1355
+
1356
+
1357
+ # region arguments
1358
+
1359
+ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1360
+ # for pretrained models
1361
+ parser.add_argument("--v2", action='store_true',
1362
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
1363
+ parser.add_argument("--v_parameterization", action='store_true',
1364
+ help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
+ help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
+
1368
+
1369
+ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
1370
+ parser.add_argument("--output_dir", type=str, default=None,
1371
+ help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
1372
+ parser.add_argument("--output_name", type=str, default=None,
1373
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
1374
+ parser.add_argument("--save_precision", type=str, default=None,
1375
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
1376
+ parser.add_argument("--save_every_n_epochs", type=int, default=None,
1377
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
1378
+ parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
1379
+ help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
1380
+ parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
1381
+ parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
1382
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
1383
+ parser.add_argument("--save_state", action="store_true",
1384
+ help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
1385
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1386
+
1387
+ parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
+ parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
+ help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
+ parser.add_argument("--use_8bit_adam", action="store_true",
1391
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1393
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
+ parser.add_argument("--mem_eff_attn", action="store_true",
1395
+ help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
+ parser.add_argument("--xformers", action="store_true",
1397
+ help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
1398
+ parser.add_argument("--vae", type=str, default=None,
1399
+ help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
+
1401
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
+ parser.add_argument("--max_train_epochs", type=int, default=None,
1404
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
1405
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
1406
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
1407
+ parser.add_argument("--persistent_data_loader_workers", action="store_true",
1408
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
1409
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1410
+ parser.add_argument("--gradient_checkpointing", action="store_true",
1411
+ help="enable gradient checkpointing / grandient checkpointingを有効にする")
1412
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
1413
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
1414
+ parser.add_argument("--mixed_precision", type=str, default="no",
1415
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
1416
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1417
+ parser.add_argument("--clip_skip", type=int, default=None,
1418
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
1419
+ parser.add_argument("--logging_dir", type=str, default=None,
1420
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
+ parser.add_argument("--noise_offset", type=float, default=None,
1427
+ help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
+ parser.add_argument("--lowram", action="store_true",
1429
+ help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
+
1431
+ if support_dreambooth:
1432
+ # DreamBooth training
1433
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0,
1434
+ help="loss weight for regularization images / 正則化画像のlossの重み")
1435
+
1436
+
1437
+ def verify_training_args(args: argparse.Namespace):
1438
+ if args.v_parameterization and not args.v2:
1439
+ print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
1440
+ if args.v2 and args.clip_skip is not None:
1441
+ print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1442
+
1443
+
1444
+ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
1445
+ # dataset common
1446
+ parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
1447
+ parser.add_argument("--shuffle_caption", action="store_true",
1448
+ help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
1449
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
+ parser.add_argument("--caption_extention", type=str, default=None,
1451
+ help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
+ parser.add_argument("--keep_tokens", type=int, default=None,
1453
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
+ parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
+ parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
+ parser.add_argument("--face_crop_aug_range", type=str, default=None,
1457
+ help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
1458
+ parser.add_argument("--random_crop", action="store_true",
1459
+ help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
1460
+ parser.add_argument("--debug_dataset", action="store_true",
1461
+ help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
1462
+ parser.add_argument("--resolution", type=str, default=None,
1463
+ help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
1464
+ parser.add_argument("--cache_latents", action="store_true",
1465
+ help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
1466
+ parser.add_argument("--enable_bucket", action="store_true",
1467
+ help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
1468
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
1469
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
1470
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
1471
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
1472
+ parser.add_argument("--bucket_no_upscale", action="store_true",
1473
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
1474
+
1475
+ if support_caption_dropout:
1476
+ # Textual Inversion does not support caption dropout
1477
+ # the options are prefixed with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults to None for consistency with the other options
1478
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
+ help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
+ help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
+ help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
+
1485
+ if support_dreambooth:
1486
+ # DreamBooth dataset
1487
+ parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
1488
+
1489
+ if support_caption:
1490
+ # caption dataset
1491
+ parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
1492
+ parser.add_argument("--dataset_repeats", type=int, default=1,
1493
+ help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
1494
+
1495
+
1496
+ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1497
+ parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
1498
+ help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
1499
+ parser.add_argument("--use_safetensors", action='store_true',
1500
+ help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
1501
+
1502
+ # endregion
1503
+
1504
+ # region utils
1505
+
1506
+
1507
+ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
+ # backward compatibility
1509
+ if args.caption_extention is not None:
1510
+ args.caption_extension = args.caption_extention
1511
+ args.caption_extention = None
1512
+
1513
+ if args.cache_latents:
1514
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
+
1517
+ # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
+ if args.resolution is not None:
1519
+ args.resolution = tuple([int(r) for r in args.resolution.split(',')])
1520
+ if len(args.resolution) == 1:
1521
+ args.resolution = (args.resolution[0], args.resolution[0])
1522
+ assert len(args.resolution) == 2, \
1523
+ f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
1524
+
1525
+ if args.face_crop_aug_range is not None:
1526
+ args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
1527
+ assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
1528
+ f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
1529
+ else:
1530
+ args.face_crop_aug_range = None
1531
+
1532
+ if support_metadata:
1533
+ if args.in_json is not None and (args.color_aug or args.random_crop):
1534
+ print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
1535
+
1536
+
1537
+ def load_tokenizer(args: argparse.Namespace):
1538
+ print("prepare tokenizer")
1539
+ if args.v2:
1540
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
+ else:
1542
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
+ if args.max_token_length is not None:
1544
+ print(f"update token length: {args.max_token_length}")
1545
+ return tokenizer
1546
+
1547
+
1548
+ def prepare_accelerator(args: argparse.Namespace):
1549
+ if args.logging_dir is None:
1550
+ log_with = None
1551
+ logging_dir = None
1552
+ else:
1553
+ log_with = "tensorboard"
1554
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
1555
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
1556
+
1557
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
1558
+ log_with=log_with, logging_dir=logging_dir)
1559
+
1560
+ # resolve an accelerate compatibility issue
1561
+ accelerator_0_15 = True
1562
+ try:
1563
+ accelerator.unwrap_model("dummy", True)
1564
+ print("Using accelerator 0.15.0 or above.")
1565
+ except TypeError:
1566
+ accelerator_0_15 = False
1567
+
1568
+ def unwrap_model(model):
1569
+ if accelerator_0_15:
1570
+ return accelerator.unwrap_model(model, True)
1571
+ return accelerator.unwrap_model(model)
1572
+
1573
+ return accelerator, unwrap_model
1574
+
1575
+
1576
+ def prepare_dtype(args: argparse.Namespace):
1577
+ weight_dtype = torch.float32
1578
+ if args.mixed_precision == "fp16":
1579
+ weight_dtype = torch.float16
1580
+ elif args.mixed_precision == "bf16":
1581
+ weight_dtype = torch.bfloat16
1582
+
1583
+ save_dtype = None
1584
+ if args.save_precision == "fp16":
1585
+ save_dtype = torch.float16
1586
+ elif args.save_precision == "bf16":
1587
+ save_dtype = torch.bfloat16
1588
+ elif args.save_precision == "float":
1589
+ save_dtype = torch.float32
1590
+
1591
+ return weight_dtype, save_dtype
1592
+
1593
+
1594
+ def load_target_model(args: argparse.Namespace, weight_dtype):
1595
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
1596
+ if load_stable_diffusion_format:
1597
+ print("load StableDiffusion checkpoint")
1598
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
+ else:
1600
+ print("load Diffusers pretrained models")
1601
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1602
+ text_encoder = pipe.text_encoder
1603
+ vae = pipe.vae
1604
+ unet = pipe.unet
1605
+ del pipe
1606
+
1607
+ # load the VAE
1608
+ if args.vae is not None:
1609
+ vae = model_util.load_vae(args.vae, weight_dtype)
1610
+ print("additional VAE loaded")
1611
+
1612
+ return text_encoder, vae, unet, load_stable_diffusion_format
1613
+
1614
+
1615
+ def patch_accelerator_for_fp16_training(accelerator):
1616
+ org_unscale_grads = accelerator.scaler._unscale_grads_
1617
+
1618
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1619
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
1620
+
1621
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
1622
+
1623
+
1624
+ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
1625
+ # with no_token_padding, the length is not max length, return result immediately
1626
+ if input_ids.size()[-1] != tokenizer.model_max_length:
1627
+ return text_encoder(input_ids)[0]
1628
+
1629
+ b_size = input_ids.size()[0]
1630
+ input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
1631
+
1632
+ if args.clip_skip is None:
1633
+ encoder_hidden_states = text_encoder(input_ids)[0]
1634
+ else:
1635
+ enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
1636
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
1637
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
1638
+
1639
+ # bs*3, 77, 768 or 1024
1640
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
1641
+
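+ # when max_token_length is set, the prompt was encoded as several 77-token chunks, each with its
+ # own <BOS>/<EOS>; the code below drops the duplicated special tokens and concatenates the chunks
+ # back into a single hidden-state sequence per sample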
1642
+ if args.max_token_length is not None:
1643
+ if args.v2:
1644
+ # v2: merge the <BOS>...<EOS> <PAD> ... triplets back into a single <BOS>...<EOS> <PAD> ... sequence (honestly not sure this implementation is right)
1645
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1646
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1647
+ chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # from after <BOS> up to just before the last token
1648
+ if i > 0:
1649
+ for j in range(len(chunk)):
1650
+ if input_ids[j, 1] == tokenizer.eos_token: # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
1651
+ chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
1652
+ states_list.append(chunk) # from after <BOS> to before <EOS>
1653
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
1654
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1655
+ else:
1656
+ # v1: merge the <BOS>...<EOS> triplets back into a single <BOS>...<EOS> sequence
1657
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1658
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1659
+ states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # from after <BOS> to before <EOS>
1660
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
1661
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1662
+
1663
+ if weight_dtype is not None:
1664
+ # this is required for additional network training
1665
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
1666
+
1667
+ return encoder_hidden_states
1668
+
1669
+
1670
+ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
1671
+ model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
1672
+ ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
1673
+ return model_name, ckpt_name
1674
+
1675
+
1676
+ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
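+ # note: the final epoch is excluded from the condition below because the end-of-training save handles it separately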
1677
+ saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
1678
+ if saving:
1679
+ os.makedirs(args.output_dir, exist_ok=True)
1680
+ save_func()
1681
+
1682
+ if args.save_last_n_epochs is not None:
1683
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
1684
+ remove_old_func(remove_epoch_no)
1685
+ return saving
1686
+
1687
+
1688
+ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
1689
+ epoch_no = epoch + 1
1690
+ model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
1691
+
1692
+ if save_stable_diffusion_format:
1693
+ def save_sd():
1694
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1695
+ print(f"saving checkpoint: {ckpt_file}")
1696
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1697
+ src_path, epoch_no, global_step, save_dtype, vae)
1698
+
1699
+ def remove_sd(old_epoch_no):
1700
+ _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
1701
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1702
+ if os.path.exists(old_ckpt_file):
1703
+ print(f"removing old checkpoint: {old_ckpt_file}")
1704
+ os.remove(old_ckpt_file)
1705
+
1706
+ save_func = save_sd
1707
+ remove_old_func = remove_sd
1708
+ else:
1709
+ def save_du():
1710
+ out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
1711
+ print(f"saving model: {out_dir}")
1712
+ os.makedirs(out_dir, exist_ok=True)
1713
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1714
+ src_path, vae=vae, use_safetensors=use_safetensors)
1715
+
1716
+ def remove_du(old_epoch_no):
1717
+ out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
1718
+ if os.path.exists(out_dir_old):
1719
+ print(f"removing old model: {out_dir_old}")
1720
+ shutil.rmtree(out_dir_old)
1721
+
1722
+ save_func = save_du
1723
+ remove_old_func = remove_du
1724
+
1725
+ saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
1726
+ if saving and args.save_state:
1727
+ save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
1728
+
1729
+
1730
+ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
1731
+ print("saving state.")
1732
+ accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
1733
+
1734
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
1735
+ if last_n_epochs is not None:
1736
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
1737
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
1738
+ if os.path.exists(state_dir_old):
1739
+ print(f"removing old state: {state_dir_old}")
1740
+ shutil.rmtree(state_dir_old)
1741
+
1742
+
1743
+ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
1744
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1745
+
1746
+ if save_stable_diffusion_format:
1747
+ os.makedirs(args.output_dir, exist_ok=True)
1748
+
1749
+ ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
1750
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1751
+
1752
+ print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
1753
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1754
+ src_path, epoch, global_step, save_dtype, vae)
1755
+ else:
1756
+ out_dir = os.path.join(args.output_dir, model_name)
1757
+ os.makedirs(out_dir, exist_ok=True)
1758
+
1759
+ print(f"save trained model as Diffusers to {out_dir}")
1760
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1761
+ src_path, vae=vae, use_safetensors=use_safetensors)
1762
+
1763
+
1764
+ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1765
+ print("saving last state.")
1766
+ os.makedirs(args.output_dir, exist_ok=True)
1767
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
+ accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
+
1770
+ # endregion
1771
+
1772
+ # region preprocessing
1773
+
1774
+
1775
+ class ImageLoadingDataset(torch.utils.data.Dataset):
1776
+ def __init__(self, image_paths):
1777
+ self.images = image_paths
1778
+
1779
+ def __len__(self):
1780
+ return len(self.images)
1781
+
1782
+ def __getitem__(self, idx):
1783
+ img_path = self.images[idx]
1784
+
1785
+ try:
1786
+ image = Image.open(img_path).convert("RGB")
1787
+ # convert to tensor temporarily so dataloader will accept it
1788
+ tensor_pil = transforms.functional.pil_to_tensor(image)
1789
+ except Exception as e:
1790
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
1791
+ return None
1792
+
1793
+ return (tensor_pil, img_path)
1794
+
1795
+
1796
+ # endregion
fine_tune.py ADDED
@@ -0,0 +1,360 @@
1
+ # training with captions
2
+ # XXX dropped option: hypernetwork training
3
+
4
+ import argparse
5
+ import gc
6
+ import math
7
+ import os
8
+
9
+ from tqdm import tqdm
10
+ import torch
11
+ from accelerate.utils import set_seed
12
+ import diffusers
13
+ from diffusers import DDPMScheduler
14
+
15
+ import library.train_util as train_util
16
+
17
+
18
+ def collate_fn(examples):
19
+ return examples[0]
20
+
21
+
22
+ def train(args):
23
+ train_util.verify_training_args(args)
24
+ train_util.prepare_dataset_args(args, True)
25
+
26
+ cache_latents = args.cache_latents
27
+
28
+ if args.seed is not None:
29
+ set_seed(args.seed) # initialize the random number generators
30
+
31
+ tokenizer = train_util.load_tokenizer(args)
32
+
33
+ train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
34
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
35
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
36
+ args.bucket_reso_steps, args.bucket_no_upscale,
37
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
38
+ args.dataset_repeats, args.debug_dataset)
39
+
40
+ # set the caption dropout rates for the training data
41
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
42
+
43
+ train_dataset.make_buckets()
44
+
45
+ if args.debug_dataset:
46
+ train_util.debug_dataset(train_dataset)
47
+ return
48
+ if len(train_dataset) == 0:
49
+ print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
50
+ return
51
+
52
+ # prepare the accelerator
53
+ print("prepare accelerator")
54
+ accelerator, unwrap_model = train_util.prepare_accelerator(args)
55
+
56
+ # prepare dtypes matching the mixed precision setting and cast as appropriate
57
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
58
+
59
+ # load the models
60
+ text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
61
+
62
+ # verify load/save model formats
63
+ if load_stable_diffusion_format:
64
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
65
+ src_diffusers_model_path = None
66
+ else:
67
+ src_stable_diffusion_ckpt = None
68
+ src_diffusers_model_path = args.pretrained_model_name_or_path
69
+
70
+ if args.save_model_as is None:
71
+ save_stable_diffusion_format = load_stable_diffusion_format
72
+ use_safetensors = args.use_safetensors
73
+ else:
74
+ save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
75
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
76
+
77
+ # helper to set the xformers flag on Diffusers modules
78
+ def set_diffusers_xformers_flag(model, valid):
79
+ # model.set_use_memory_efficient_attention_xformers(valid) # likely to be removed in a future release
80
+ # the pipeline apparently searches recursively for set_use_memory_efficient_attention_xformers on its own
81
+ # not sure what to do when only the U-Net is used... no choice but to copy the recursive walk here
82
+ # in 0.10.2 this somehow rolled back to setting the flag per module again
83
+
84
+ # Recursively walk through all the children.
85
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
86
+ # gets the message
87
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
88
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
89
+ module.set_use_memory_efficient_attention_xformers(valid)
90
+
91
+ for child in module.children():
92
+ fn_recursive_set_mem_eff(child)
93
+
94
+ fn_recursive_set_mem_eff(model)
95
+
96
+ # hook xformers or memory efficient attention into the model
97
+ if args.diffusers_xformers:
98
+ print("Use xformers by Diffusers")
99
+ set_diffusers_xformers_flag(unet, True)
100
+ else:
101
+ # the Windows build of xformers cannot train in float32, so a configuration without xformers must remain possible
102
+ print("Disable Diffusers' xformers")
103
+ set_diffusers_xformers_flag(unet, False)
104
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
105
+
106
+ # prepare for training
107
+ if cache_latents:
108
+ vae.to(accelerator.device, dtype=weight_dtype)
109
+ vae.requires_grad_(False)
110
+ vae.eval()
111
+ with torch.no_grad():
112
+ train_dataset.cache_latents(vae)
113
+ vae.to("cpu")
114
+ if torch.cuda.is_available():
115
+ torch.cuda.empty_cache()
116
+ gc.collect()
117
+
118
+ # prepare for training: put the models into the proper state
119
+ training_models = []
120
+ if args.gradient_checkpointing:
121
+ unet.enable_gradient_checkpointing()
122
+ training_models.append(unet)
123
+
124
+ if args.train_text_encoder:
125
+ print("enable text encoder training")
126
+ if args.gradient_checkpointing:
127
+ text_encoder.gradient_checkpointing_enable()
128
+ training_models.append(text_encoder)
129
+ else:
130
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
131
+ text_encoder.requires_grad_(False) # the text encoder is not trained
132
+ if args.gradient_checkpointing:
133
+ text_encoder.gradient_checkpointing_enable()
134
+ text_encoder.train() # required for gradient_checkpointing
135
+ else:
136
+ text_encoder.eval()
137
+
138
+ if not cache_latents:
139
+ vae.requires_grad_(False)
140
+ vae.eval()
141
+ vae.to(accelerator.device, dtype=weight_dtype)
142
+
143
+ for m in training_models:
144
+ m.requires_grad_(True)
145
+ params = []
146
+ for m in training_models:
147
+ params.extend(m.parameters())
148
+ params_to_optimize = params
149
+
150
+ # prepare the classes needed for training
151
+ print("prepare optimizer, data loader etc.")
152
+
153
+ # use 8-bit Adam
154
+ if args.use_8bit_adam:
155
+ try:
156
+ import bitsandbytes as bnb
157
+ except ImportError:
158
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
159
+ print("use 8-bit Adam optimizer")
160
+ optimizer_class = bnb.optim.AdamW8bit
161
+ elif args.use_lion_optimizer:
162
+ try:
163
+ import lion_pytorch
164
+ except ImportError:
165
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
166
+ print("use Lion optimizer")
167
+ optimizer_class = lion_pytorch.Lion
168
+ else:
169
+ optimizer_class = torch.optim.AdamW
170
+
171
+ # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
172
+ optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
173
+
174
+ # prepare the dataloader
175
+ # number of DataLoader worker processes: 0 means the main process
176
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count - 1, but capped at the specified maximum
177
+ train_dataloader = torch.utils.data.DataLoader(
178
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
179
+
180
+ # compute the number of training steps
181
+ if args.max_train_epochs is not None:
182
+ args.max_train_steps = args.max_train_epochs * len(train_dataloader)
183
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
+
185
+ # prepare the lr scheduler
186
+ lr_scheduler = diffusers.optimization.get_scheduler(
187
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
188
+
189
+ # experimental feature: train fully in fp16, including gradients; cast the whole model to fp16
190
+ if args.full_fp16:
191
+ assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
192
+ print("enable full fp16 training.")
193
+ unet.to(weight_dtype)
194
+ text_encoder.to(weight_dtype)
195
+
196
+ # accelerator apparently takes care of the wrapping and device placement here
197
+ if args.train_text_encoder:
198
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
199
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
200
+ else:
201
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
202
+
203
+ # experimental feature: full fp16 training including gradients; patch PyTorch to enable grad scaling on fp16 gradients
204
+ if args.full_fp16:
205
+ train_util.patch_accelerator_for_fp16_training(accelerator)
206
+
207
+ # resume training if requested
208
+ if args.resume is not None:
209
+ print(f"resume training from state: {args.resume}")
210
+ accelerator.load_state(args.resume)
211
+
212
+ # compute the number of epochs
213
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
214
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
215
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
216
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
217
+
218
+ # train
219
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
220
+ print("running training / 学習開始")
221
+ print(f" num examples / サンプル数: {train_dataset.num_train_images}")
222
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
223
+ print(f" num epochs / epoch数: {num_train_epochs}")
224
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
225
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
226
+ print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
227
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
228
+
229
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
230
+ global_step = 0
231
+
232
+ noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
233
+ num_train_timesteps=1000, clip_sample=False)
234
+
235
+ if accelerator.is_main_process:
236
+ accelerator.init_trackers("finetuning")
237
+
238
+ for epoch in range(num_train_epochs):
239
+ print(f"epoch {epoch+1}/{num_train_epochs}")
240
+ train_dataset.set_current_epoch(epoch + 1)
241
+
242
+ for m in training_models:
243
+ m.train()
244
+
245
+ loss_total = 0
246
+ for step, batch in enumerate(train_dataloader):
247
+ with accelerator.accumulate(training_models[0]): # accumulate() does not seem to support multiple models, but this works for now
248
+ with torch.no_grad():
249
+ if "latents" in batch and batch["latents"] is not None:
250
+ latents = batch["latents"].to(accelerator.device)
251
+ else:
252
+ # convert to latents
253
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
254
+ latents = latents * 0.18215
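+ # (0.18215 is the standard Stable Diffusion latent scaling factor)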
255
+ b_size = latents.shape[0]
256
+
257
+ with torch.set_grad_enabled(args.train_text_encoder):
258
+ # Get the text embedding for conditioning
259
+ input_ids = batch["input_ids"].to(accelerator.device)
260
+ encoder_hidden_states = train_util.get_hidden_states(
261
+ args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
262
+
263
+ # Sample noise that we'll add to the latents
264
+ noise = torch.randn_like(latents, device=latents.device)
265
+ if args.noise_offset:
266
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
267
+ noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
268
+
269
+ # Sample a random timestep for each image
270
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
271
+ timesteps = timesteps.long()
272
+
273
+ # Add noise to the latents according to the noise magnitude at each timestep
274
+ # (this is the forward diffusion process)
275
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
276
+
277
+ # Predict the noise residual
278
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
279
+
280
+ if args.v_parameterization:
281
+ # v-parameterization training
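+ # the target is the velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents,
+ # as used by the SD 2.x 768-v models, instead of the raw noise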
282
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
283
+ else:
284
+ target = noise
285
+
286
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
287
+
288
+ accelerator.backward(loss)
289
+ if accelerator.sync_gradients:
290
+ params_to_clip = []
291
+ for m in training_models:
292
+ params_to_clip.extend(m.parameters())
293
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
294
+
295
+ optimizer.step()
296
+ lr_scheduler.step()
297
+ optimizer.zero_grad(set_to_none=True)
298
+
299
+ # Checks if the accelerator has performed an optimization step behind the scenes
300
+ if accelerator.sync_gradients:
301
+ progress_bar.update(1)
302
+ global_step += 1
303
+
304
+ current_loss = loss.detach().item() # already a mean, so batch size should not matter
305
+ if args.logging_dir is not None:
306
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
307
+ accelerator.log(logs, step=global_step)
308
+
309
+ loss_total += current_loss
310
+ avr_loss = loss_total / (step+1)
311
+ logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
312
+ progress_bar.set_postfix(**logs)
313
+
314
+ if global_step >= args.max_train_steps:
315
+ break
316
+
317
+ if args.logging_dir is not None:
318
+ logs = {"epoch_loss": loss_total / len(train_dataloader)}
319
+ accelerator.log(logs, step=epoch+1)
320
+
321
+ accelerator.wait_for_everyone()
322
+
323
+ if args.save_every_n_epochs is not None:
324
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
325
+ train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
326
+ save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
327
+
328
+ is_main_process = accelerator.is_main_process
329
+ if is_main_process:
330
+ unet = unwrap_model(unet)
331
+ text_encoder = unwrap_model(text_encoder)
332
+
333
+ accelerator.end_training()
334
+
335
+ if args.save_state:
336
+ train_util.save_state_on_train_end(args, accelerator)
337
+
338
+ del accelerator # delete this because memory is needed below
339
+
340
+ if is_main_process:
341
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
342
+ train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
343
+ save_dtype, epoch, global_step, text_encoder, unet, vae)
344
+ print("model saved.")
345
+
346
+
347
+ if __name__ == '__main__':
348
+ parser = argparse.ArgumentParser()
349
+
350
+ train_util.add_sd_models_arguments(parser)
351
+ train_util.add_dataset_arguments(parser, False, True, True)
352
+ train_util.add_training_arguments(parser, False)
353
+ train_util.add_sd_saving_arguments(parser)
354
+
355
+ parser.add_argument("--diffusers_xformers", action='store_true',
356
+ help='use xformers by diffusers / Diffusersでxformersを使用する')
357
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
358
+
359
+ args = parser.parse_args()
360
+ train(args)
gen_img_diffusers.py ADDED
The diff for this file is too large to render. See raw diff
 
library.egg-info/PKG-INFO ADDED
@@ -0,0 +1,4 @@
1
+ Metadata-Version: 2.1
2
+ Name: library
3
+ Version: 0.0.0
4
+ License-File: LICENSE.md
library.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,10 @@
1
+ LICENSE.md
2
+ README.md
3
+ setup.py
4
+ library/__init__.py
5
+ library/model_util.py
6
+ library/train_util.py
7
+ library.egg-info/PKG-INFO
8
+ library.egg-info/SOURCES.txt
9
+ library.egg-info/dependency_links.txt
10
+ library.egg-info/top_level.txt
library.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
library.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ library
library/__init__.py ADDED
File without changes
library/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (131 Bytes). View file
 
library/__pycache__/model_util.cpython-310.pyc ADDED
Binary file (29.2 kB). View file
 
library/__pycache__/train_util.cpython-310.pyc ADDED
Binary file (57.4 kB). View file
 
library/model_util.py ADDED
@@ -0,0 +1,1180 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ # Stable Diffusion model parameters (Diffusers version)
12
+ NUM_TRAIN_TIMESTEPS = 1000
13
+ BETA_START = 0.00085
14
+ BETA_END = 0.0120
15
+
16
+ UNET_PARAMS_MODEL_CHANNELS = 320
17
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
18
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
19
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
20
+ UNET_PARAMS_IN_CHANNELS = 4
21
+ UNET_PARAMS_OUT_CHANNELS = 4
22
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
23
+ UNET_PARAMS_CONTEXT_DIM = 768
24
+ UNET_PARAMS_NUM_HEADS = 8
25
+
26
+ VAE_PARAMS_Z_CHANNELS = 4
27
+ VAE_PARAMS_RESOLUTION = 256
28
+ VAE_PARAMS_IN_CHANNELS = 3
29
+ VAE_PARAMS_OUT_CH = 3
30
+ VAE_PARAMS_CH = 128
31
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
32
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
33
+
34
+ # V2
35
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
36
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
37
+
38
+ # reference models for loading the Diffusers configs
39
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
40
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
41
+
42
+
43
+ # region StableDiffusion -> Diffusers conversion code
44
+ # copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
45
+
46
+
47
+ def shave_segments(path, n_shave_prefix_segments=1):
48
+ """
49
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
50
+ """
51
+ if n_shave_prefix_segments >= 0:
52
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
53
+ else:
54
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
55
+
56
+
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
+
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
+
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
+
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
+
111
+ mapping.append({"old": old_item, "new": new_item})
112
+
113
+ return mapping
114
+
115
+
116
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
+
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
+
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
+
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
+
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
+
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
+
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
+
141
+ mapping.append({"old": old_item, "new": new_item})
142
+
143
+ return mapping
144
+
145
+
146
+ def assign_to_checkpoint(
147
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
+
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
+
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
+
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
+
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
+
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
+
175
+ for path in paths:
176
+ new_path = path["new"]
177
+
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
+
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
+
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
+
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
+
197
+
198
+ def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
+
209
+
210
+ def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
+
218
+
219
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
251
+ for layer_id in range(num_input_blocks)
252
+ }
253
+
254
+ # Retrieves the keys for the middle blocks only
255
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
256
+ middle_blocks = {
257
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
258
+ for layer_id in range(num_middle_blocks)
259
+ }
260
+
261
+ # Retrieves the keys for the output blocks only
262
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
263
+ output_blocks = {
264
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
265
+ for layer_id in range(num_output_blocks)
266
+ }
267
+
268
+ for i in range(1, num_input_blocks):
269
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
270
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
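+ # input_blocks.0 is conv_in; after that each resolution level contributes layers_per_block
+ # resnets plus one downsampler, hence the division by layers_per_block + 1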
271
+
272
+ resnets = [
273
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
274
+ ]
275
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
276
+
277
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
278
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
279
+ f"input_blocks.{i}.0.op.weight"
280
+ )
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.bias"
283
+ )
284
+
285
+ paths = renew_resnet_paths(resnets)
286
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
287
+ assign_to_checkpoint(
288
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
289
+ )
290
+
291
+ if len(attentions):
292
+ paths = renew_attention_paths(attentions)
293
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
294
+ assign_to_checkpoint(
295
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
296
+ )
297
+
298
+ resnet_0 = middle_blocks[0]
299
+ attentions = middle_blocks[1]
300
+ resnet_1 = middle_blocks[2]
301
+
302
+ resnet_0_paths = renew_resnet_paths(resnet_0)
303
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
304
+
305
+ resnet_1_paths = renew_resnet_paths(resnet_1)
306
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ attentions_paths = renew_attention_paths(attentions)
309
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
310
+ assign_to_checkpoint(
311
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
312
+ )
313
+
314
+ for i in range(num_output_blocks):
315
+ block_id = i // (config["layers_per_block"] + 1)
316
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
317
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
318
+ output_block_list = {}
319
+
320
+ for layer in output_block_layers:
321
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
322
+ if layer_id in output_block_list:
323
+ output_block_list[layer_id].append(layer_name)
324
+ else:
325
+ output_block_list[layer_id] = [layer_name]
326
+
327
+ if len(output_block_list) > 1:
328
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
329
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
330
+
331
+ resnet_0_paths = renew_resnet_paths(resnets)
332
+ paths = renew_resnet_paths(resnets)
333
+
334
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
335
+ assign_to_checkpoint(
336
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
337
+ )
338
+
339
+ # original:
340
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
341
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
342
+
343
+ # avoid depending on the order of bias and weight; there is probably a better way
344
+ for l in output_block_list.values():
345
+ l.sort()
346
+
347
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
348
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
349
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
350
+ f"output_blocks.{i}.{index}.conv.bias"
351
+ ]
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.weight"
354
+ ]
355
+
356
+ # Clear attentions as they have been attributed above.
357
+ if len(attentions) == 2:
358
+ attentions = []
359
+
360
+ if len(attentions):
361
+ paths = renew_attention_paths(attentions)
362
+ meta_path = {
363
+ "old": f"output_blocks.{i}.1",
364
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
365
+ }
366
+ assign_to_checkpoint(
367
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
368
+ )
369
+ else:
370
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
371
+ for path in resnet_0_paths:
372
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
373
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
374
+
375
+ new_checkpoint[new_path] = unet_state_dict[old_path]
376
+
377
+ # in SD v2 the 1x1 conv2d layers were changed to linear, so convert linear -> conv here
378
+ if v2:
379
+ linear_transformer_to_conv(new_checkpoint)
380
+
381
+ return new_checkpoint
382
+
383
+
384
+ def convert_ldm_vae_checkpoint(checkpoint, config):
385
+ # extract state dict for VAE
386
+ vae_state_dict = {}
387
+ vae_key = "first_stage_model."
388
+ keys = list(checkpoint.keys())
389
+ for key in keys:
390
+ if key.startswith(vae_key):
391
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
392
+ # if len(vae_state_dict) == 0:
393
+ # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
394
+ # vae_state_dict = checkpoint
395
+
396
+ new_checkpoint = {}
397
+
398
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
399
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
400
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
401
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
402
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
403
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
404
+
405
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
406
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
407
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
408
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
409
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
410
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
411
+
412
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
413
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
414
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
415
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
416
+
417
+ # Retrieves the keys for the encoder down blocks only
418
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
419
+ down_blocks = {
420
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
421
+ }
422
+
423
+ # Retrieves the keys for the decoder up blocks only
424
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
425
+ up_blocks = {
426
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
427
+ }
428
+
429
+ for i in range(num_down_blocks):
430
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
431
+
432
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
433
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
434
+ f"encoder.down.{i}.downsample.conv.weight"
435
+ )
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.bias"
438
+ )
439
+
440
+ paths = renew_vae_resnet_paths(resnets)
441
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
442
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
443
+
444
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
445
+ num_mid_res_blocks = 2
446
+ for i in range(1, num_mid_res_blocks + 1):
447
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
448
+
449
+ paths = renew_vae_resnet_paths(resnets)
450
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
451
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
452
+
453
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
454
+ paths = renew_vae_attention_paths(mid_attentions)
455
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
456
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
457
+ conv_attn_to_linear(new_checkpoint)
458
+
459
+ for i in range(num_up_blocks):
460
+ block_id = num_up_blocks - 1 - i
461
+ resnets = [
462
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
463
+ ]
464
+
465
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
466
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
467
+ f"decoder.up.{block_id}.upsample.conv.weight"
468
+ ]
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.bias"
471
+ ]
472
+
473
+ paths = renew_vae_resnet_paths(resnets)
474
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
475
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
476
+
477
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
478
+ num_mid_res_blocks = 2
479
+ for i in range(1, num_mid_res_blocks + 1):
480
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
481
+
482
+ paths = renew_vae_resnet_paths(resnets)
483
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
484
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
485
+
486
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
487
+ paths = renew_vae_attention_paths(mid_attentions)
488
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
489
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
490
+ conv_attn_to_linear(new_checkpoint)
491
+ return new_checkpoint
492
+
493
+
494
+ def create_unet_diffusers_config(v2):
495
+ """
496
+ Creates a config for the diffusers based on the config of the LDM model.
497
+ """
498
+ # unet_params = original_config.model.params.unet_config.params
499
+
500
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
501
+
502
+ down_block_types = []
503
+ resolution = 1
504
+ for i in range(len(block_out_channels)):
505
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
506
+ down_block_types.append(block_type)
507
+ if i != len(block_out_channels) - 1:
508
+ resolution *= 2
509
+
510
+ up_block_types = []
511
+ for i in range(len(block_out_channels)):
512
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
513
+ up_block_types.append(block_type)
514
+ resolution //= 2
515
+
516
+ config = dict(
517
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
518
+ in_channels=UNET_PARAMS_IN_CHANNELS,
519
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
520
+ down_block_types=tuple(down_block_types),
521
+ up_block_types=tuple(up_block_types),
522
+ block_out_channels=tuple(block_out_channels),
523
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
524
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
525
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
526
+ )
527
+
528
+ return config
529
+
530
+
531
+ def create_vae_diffusers_config():
532
+ """
533
+ Creates a config for the diffusers based on the config of the LDM model.
534
+ """
535
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
536
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
537
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
538
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
539
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
540
+
541
+ config = dict(
542
+ sample_size=VAE_PARAMS_RESOLUTION,
543
+ in_channels=VAE_PARAMS_IN_CHANNELS,
544
+ out_channels=VAE_PARAMS_OUT_CH,
545
+ down_block_types=tuple(down_block_types),
546
+ up_block_types=tuple(up_block_types),
547
+ block_out_channels=tuple(block_out_channels),
548
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
549
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
550
+ )
551
+ return config
552
+
553
+
554
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
555
+ keys = list(checkpoint.keys())
556
+ text_model_dict = {}
557
+ for key in keys:
558
+ if key.startswith("cond_stage_model.transformer"):
559
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
560
+ return text_model_dict
561
+
562
+
563
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
564
+ # the key layouts are annoyingly different!
565
+ def convert_key(key):
566
+ if not key.startswith("cond_stage_model"):
567
+ return None
568
+
569
+ # common conversion
570
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
571
+ key = key.replace("cond_stage_model.model.", "text_model.")
572
+
573
+ if "resblocks" in key:
574
+ # resblocks conversion
575
+ key = key.replace(".resblocks.", ".layers.")
576
+ if ".ln_" in key:
577
+ key = key.replace(".ln_", ".layer_norm")
578
+ elif ".mlp." in key:
579
+ key = key.replace(".c_fc.", ".fc1.")
580
+ key = key.replace(".c_proj.", ".fc2.")
581
+ elif '.attn.out_proj' in key:
582
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
583
+ elif '.attn.in_proj' in key:
584
+ key = None # special case, handled later
585
+ else:
586
+ raise ValueError(f"unexpected key in SD: {key}")
587
+ elif '.positional_embedding' in key:
588
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
589
+ elif '.text_projection' in key:
590
+ key = None # not used???
591
+ elif '.logit_scale' in key:
592
+ key = None # not used???
593
+ elif '.token_embedding' in key:
594
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
595
+ elif '.ln_final' in key:
596
+ key = key.replace(".ln_final", ".final_layer_norm")
597
+ return key
598
+
599
+ keys = list(checkpoint.keys())
600
+ new_sd = {}
601
+ for key in keys:
602
+ # remove resblocks 23
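+ # (the open_clip text model has 24 blocks, but the Diffusers SD2 text encoder keeps only 23,
+ # since SD v2 conditions on the penultimate layer's output)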
603
+ if '.resblocks.23.' in key:
604
+ continue
605
+ new_key = convert_key(key)
606
+ if new_key is None:
607
+ continue
608
+ new_sd[new_key] = checkpoint[key]
609
+
610
+ # convert the attention weights
611
+ for key in keys:
612
+ if '.resblocks.23.' in key:
613
+ continue
614
+ if '.resblocks' in key and '.attn.in_proj_' in key:
615
+ # split into three
616
+ values = torch.chunk(checkpoint[key], 3)
617
+
618
+ key_suffix = ".weight" if "weight" in key else ".bias"
619
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
620
+ key_pfx = key_pfx.replace("_weight", "")
621
+ key_pfx = key_pfx.replace("_bias", "")
622
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
623
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
624
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
625
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
626
+
627
+ # rename or add position_ids
628
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
629
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
630
+ # waifu diffusion v1.4
631
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
632
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
633
+ else:
634
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
635
+
636
+ new_sd["text_model.embeddings.position_ids"] = position_ids
637
+ return new_sd
638
+
639
+ # endregion
640
+
641
+
642
+ # region Diffusers -> StableDiffusion conversion code
643
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
644
+
645
+ def conv_transformer_to_linear(checkpoint):
646
+ keys = list(checkpoint.keys())
647
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
648
+ for key in keys:
649
+ if ".".join(key.split(".")[-2:]) in tf_keys:
650
+ if checkpoint[key].ndim > 2:
651
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
652
+
653
+
654
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
655
+ unet_conversion_map = [
656
+ # (stable-diffusion, HF Diffusers)
657
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
658
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
659
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
660
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
661
+ ("input_blocks.0.0.weight", "conv_in.weight"),
662
+ ("input_blocks.0.0.bias", "conv_in.bias"),
663
+ ("out.0.weight", "conv_norm_out.weight"),
664
+ ("out.0.bias", "conv_norm_out.bias"),
665
+ ("out.2.weight", "conv_out.weight"),
666
+ ("out.2.bias", "conv_out.bias"),
667
+ ]
668
+
669
+ unet_conversion_map_resnet = [
670
+ # (stable-diffusion, HF Diffusers)
671
+ ("in_layers.0", "norm1"),
672
+ ("in_layers.2", "conv1"),
673
+ ("out_layers.0", "norm2"),
674
+ ("out_layers.3", "conv2"),
675
+ ("emb_layers.1", "time_emb_proj"),
676
+ ("skip_connection", "conv_shortcut"),
677
+ ]
678
+
679
+ unet_conversion_map_layer = []
680
+ for i in range(4):
681
+ # loop over downblocks/upblocks
682
+
683
+ for j in range(2):
684
+ # loop over resnets/attentions for downblocks
685
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
686
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
687
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
688
+
689
+ if i < 3:
690
+ # no attention layers in down_blocks.3
691
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
692
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
693
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
694
+
695
+ for j in range(3):
696
+ # loop over resnets/attentions for upblocks
697
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
698
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
699
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
700
+
701
+ if i > 0:
702
+ # no attention layers in up_blocks.0
703
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
704
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
705
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
706
+
707
+ if i < 3:
708
+ # no downsample in down_blocks.3
709
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
710
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
711
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
712
+
713
+ # no upsample in up_blocks.3
714
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
715
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
716
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
717
+
718
+ hf_mid_atn_prefix = "mid_block.attentions.0."
719
+ sd_mid_atn_prefix = "middle_block.1."
720
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
721
+
722
+ for j in range(2):
723
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
724
+ sd_mid_res_prefix = f"middle_block.{2*j}."
725
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
726
+
727
+ # buyer beware: this is a *brittle* function,
728
+ # and correct output requires that all of these pieces interact in
729
+ # the exact order in which I have arranged them.
730
+ mapping = {k: k for k in unet_state_dict.keys()}
731
+ for sd_name, hf_name in unet_conversion_map:
732
+ mapping[hf_name] = sd_name
733
+ for k, v in mapping.items():
734
+ if "resnets" in k:
735
+ for sd_part, hf_part in unet_conversion_map_resnet:
736
+ v = v.replace(hf_part, sd_part)
737
+ mapping[k] = v
738
+ for k, v in mapping.items():
739
+ for sd_part, hf_part in unet_conversion_map_layer:
740
+ v = v.replace(hf_part, sd_part)
741
+ mapping[k] = v
742
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
743
+
744
+ if v2:
745
+ conv_transformer_to_linear(new_state_dict)
746
+
747
+ return new_state_dict
748
+
749
+
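To make the replace-based renaming above concrete, here is a small illustrative sketch (the key and the two map entries are taken from the tables above; it is not part of the original file):

hf_key = "down_blocks.0.resnets.0.norm1.weight"                                  # a Diffusers UNet key
for sd_part, hf_part in [("in_layers.0", "norm1"),                               # from unet_conversion_map_resnet
                         ("input_blocks.1.0.", "down_blocks.0.resnets.0.")]:     # from unet_conversion_map_layer
    hf_key = hf_key.replace(hf_part, sd_part)
print(hf_key)  # -> input_blocks.1.0.in_layers.0.weight (the StableDiffusion name)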
750
+ # ================#
751
+ # VAE Conversion #
752
+ # ================#
753
+
754
+ def reshape_weight_for_sd(w):
755
+ # convert HF linear weights to SD conv2d weights
756
+ return w.reshape(*w.shape, 1, 1)
757
+
758
+
759
+ def convert_vae_state_dict(vae_state_dict):
760
+ vae_conversion_map = [
761
+ # (stable-diffusion, HF Diffusers)
762
+ ("nin_shortcut", "conv_shortcut"),
763
+ ("norm_out", "conv_norm_out"),
764
+ ("mid.attn_1.", "mid_block.attentions.0."),
765
+ ]
766
+
767
+ for i in range(4):
768
+ # down_blocks have two resnets
769
+ for j in range(2):
770
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
771
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
772
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
773
+
774
+ if i < 3:
775
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
776
+ sd_downsample_prefix = f"down.{i}.downsample."
777
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
778
+
779
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
780
+ sd_upsample_prefix = f"up.{3-i}.upsample."
781
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
782
+
783
+ # up_blocks have three resnets
784
+ # also, up blocks in hf are numbered in reverse from sd
785
+ for j in range(3):
786
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
787
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
788
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
789
+
790
+ # this part accounts for mid blocks in both the encoder and the decoder
791
+ for i in range(2):
792
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
793
+ sd_mid_res_prefix = f"mid.block_{i+1}."
794
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
795
+
796
+ vae_conversion_map_attn = [
797
+ # (stable-diffusion, HF Diffusers)
798
+ ("norm.", "group_norm."),
799
+ ("q.", "query."),
800
+ ("k.", "key."),
801
+ ("v.", "value."),
802
+ ("proj_out.", "proj_attn."),
803
+ ]
804
+
805
+ mapping = {k: k for k in vae_state_dict.keys()}
806
+ for k, v in mapping.items():
807
+ for sd_part, hf_part in vae_conversion_map:
808
+ v = v.replace(hf_part, sd_part)
809
+ mapping[k] = v
810
+ for k, v in mapping.items():
811
+ if "attentions" in k:
812
+ for sd_part, hf_part in vae_conversion_map_attn:
813
+ v = v.replace(hf_part, sd_part)
814
+ mapping[k] = v
815
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
816
+ weights_to_convert = ["q", "k", "v", "proj_out"]
817
+ for k, v in new_state_dict.items():
818
+ for weight_name in weights_to_convert:
819
+ if f"mid.attn_1.{weight_name}.weight" in k:
820
+ # print(f"Reshaping {k} for SD format")
821
+ new_state_dict[k] = reshape_weight_for_sd(v)
822
+
823
+ return new_state_dict
824
+
825
+
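A minimal sketch of the attention-weight reshape applied at the end of convert_vae_state_dict, assuming reshape_weight_for_sd from above is in scope (512 is only an illustrative channel count):

import torch

w_hf = torch.randn(512, 512)          # HF linear projection weight, e.g. mid.attn_1.q
w_sd = reshape_weight_for_sd(w_hf)    # same values, viewed as a 1x1 conv weight
print(w_sd.shape)                     # torch.Size([512, 512, 1, 1])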
826
+ # endregion
827
+
828
+ # region custom model loading/saving helpers
829
+
830
+ def is_safetensors(path):
831
+ return os.path.splitext(path)[1].lower() == '.safetensors'
832
+
833
+
834
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
835
+ # support models that store the text encoder differently (no 'text_model' in the keys)
836
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
837
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
838
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
839
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
840
+ ]
841
+
842
+ if is_safetensors(ckpt_path):
843
+ checkpoint = None
844
+ state_dict = load_file(ckpt_path, "cpu")
845
+ else:
846
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
847
+ if "state_dict" in checkpoint:
848
+ state_dict = checkpoint["state_dict"]
849
+ else:
850
+ state_dict = checkpoint
851
+ checkpoint = None
852
+
853
+ key_reps = []
854
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
855
+ for key in state_dict.keys():
856
+ if key.startswith(rep_from):
857
+ new_key = rep_to + key[len(rep_from):]
858
+ key_reps.append((key, new_key))
859
+
860
+ for key, new_key in key_reps:
861
+ state_dict[new_key] = state_dict[key]
862
+ del state_dict[key]
863
+
864
+ return checkpoint, state_dict
865
+
866
+
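Usage sketch for the loader above (the path is hypothetical). For safetensors files the first return value is always None, since there is no wrapping checkpoint dict:

ckpt, sd = load_checkpoint_with_text_encoder_conversion("models/some_model.safetensors")  # hypothetical path
print(ckpt)     # None for safetensors input
print(len(sd))  # number of tensors in the key-normalized state dict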
867
+ # TODO the dtype handling looks unreliable and should be checked; creating the text_encoder in the specified dtype is unverified
868
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
869
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
870
+ if dtype is not None:
871
+ for k, v in state_dict.items():
872
+ if type(v) is torch.Tensor:
873
+ state_dict[k] = v.to(dtype)
874
+
875
+ # Convert the UNet2DConditionModel model.
876
+ unet_config = create_unet_diffusers_config(v2)
877
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
878
+
879
+ unet = UNet2DConditionModel(**unet_config)
880
+ info = unet.load_state_dict(converted_unet_checkpoint)
881
+ print("loading u-net:", info)
882
+
883
+ # Convert the VAE model.
884
+ vae_config = create_vae_diffusers_config()
885
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
886
+
887
+ vae = AutoencoderKL(**vae_config)
888
+ info = vae.load_state_dict(converted_vae_checkpoint)
889
+ print("loading vae:", info)
890
+
891
+ # convert text_model
892
+ if v2:
893
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
894
+ cfg = CLIPTextConfig(
895
+ vocab_size=49408,
896
+ hidden_size=1024,
897
+ intermediate_size=4096,
898
+ num_hidden_layers=23,
899
+ num_attention_heads=16,
900
+ max_position_embeddings=77,
901
+ hidden_act="gelu",
902
+ layer_norm_eps=1e-05,
903
+ dropout=0.0,
904
+ attention_dropout=0.0,
905
+ initializer_range=0.02,
906
+ initializer_factor=1.0,
907
+ pad_token_id=1,
908
+ bos_token_id=0,
909
+ eos_token_id=2,
910
+ model_type="clip_text_model",
911
+ projection_dim=512,
912
+ torch_dtype="float32",
913
+ transformers_version="4.25.0.dev0",
914
+ )
915
+ text_model = CLIPTextModel._from_config(cfg)
916
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
+ else:
918
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
920
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
+ print("loading text encoder:", info)
922
+
923
+ return text_model, vae, unet
924
+
925
+
926
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
927
+ def convert_key(key):
928
+ # remove position_ids
929
+ if ".position_ids" in key:
930
+ return None
931
+
932
+ # common
933
+ key = key.replace("text_model.encoder.", "transformer.")
934
+ key = key.replace("text_model.", "")
935
+ if "layers" in key:
936
+ # resblocks conversion
937
+ key = key.replace(".layers.", ".resblocks.")
938
+ if ".layer_norm" in key:
939
+ key = key.replace(".layer_norm", ".ln_")
940
+ elif ".mlp." in key:
941
+ key = key.replace(".fc1.", ".c_fc.")
942
+ key = key.replace(".fc2.", ".c_proj.")
943
+ elif '.self_attn.out_proj' in key:
944
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
945
+ elif '.self_attn.' in key:
946
+ key = None # special case, handled separately below
947
+ else:
948
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
949
+ elif '.position_embedding' in key:
950
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
951
+ elif '.token_embedding' in key:
952
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
953
+ elif 'final_layer_norm' in key:
954
+ key = key.replace("final_layer_norm", "ln_final")
955
+ return key
956
+
957
+ keys = list(checkpoint.keys())
958
+ new_sd = {}
959
+ for key in keys:
960
+ new_key = convert_key(key)
961
+ if new_key is None:
962
+ continue
963
+ new_sd[new_key] = checkpoint[key]
964
+
965
+ # convert the attention weights
966
+ for key in keys:
967
+ if 'layers' in key and 'q_proj' in key:
968
+ # concatenate the three projections
969
+ key_q = key
970
+ key_k = key.replace("q_proj", "k_proj")
971
+ key_v = key.replace("q_proj", "v_proj")
972
+
973
+ value_q = checkpoint[key_q]
974
+ value_k = checkpoint[key_k]
975
+ value_v = checkpoint[key_v]
976
+ value = torch.cat([value_q, value_k, value_v])
977
+
978
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
979
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
980
+ new_sd[new_key] = value
981
+
982
+ # whether to fabricate the last layer and other missing weights
983
+ if make_dummy_weights:
984
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
985
+ keys = list(new_sd.keys())
986
+ for key in keys:
987
+ if key.startswith("transformer.resblocks.22."):
988
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # must be a copy, otherwise saving to safetensors fails
989
+
990
+ # create weights that the Diffusers model does not include
991
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
992
+ new_sd['logit_scale'] = torch.tensor(1)
993
+
994
+ return new_sd
995
+
996
+
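The q/k/v concatenation above can be illustrated with dummy tensors (1024 is the SD v2 text encoder hidden size):

import torch

q = torch.zeros(1024, 1024)
k = torch.zeros(1024, 1024)
v = torch.zeros(1024, 1024)
in_proj_weight = torch.cat([q, k, v])   # shape (3072, 1024), stored as "...attn.in_proj_weight"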
997
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
998
+ if ckpt_path is not None:
999
+ # reuse its epoch/step counters, and reload the checkpoint again (including the VAE), e.g. when the VAE is not in memory
1000
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1001
+ if checkpoint is None: # safetensors, or a ckpt that is a bare state_dict
1002
+ checkpoint = {}
1003
+ strict = False
1004
+ else:
1005
+ strict = True
1006
+ if "state_dict" in state_dict:
1007
+ del state_dict["state_dict"]
1008
+ else:
1009
+ # create a new checkpoint from scratch
1010
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1011
+ checkpoint = {}
1012
+ state_dict = {}
1013
+ strict = False
1014
+
1015
+ def update_sd(prefix, sd):
1016
+ for k, v in sd.items():
1017
+ key = prefix + k
1018
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1019
+ if save_dtype is not None:
1020
+ v = v.detach().clone().to("cpu").to(save_dtype)
1021
+ state_dict[key] = v
1022
+
1023
+ # Convert the UNet model
1024
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1025
+ update_sd("model.diffusion_model.", unet_state_dict)
1026
+
1027
+ # Convert the text encoder model
1028
+ if v2:
1029
+ make_dummy = ckpt_path is None # when there is no source checkpoint, insert dummy weights (e.g. fabricate the last layer by duplicating the previous one)
1030
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1031
+ update_sd("cond_stage_model.model.", text_enc_dict)
1032
+ else:
1033
+ text_enc_dict = text_encoder.state_dict()
1034
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1035
+
1036
+ # Convert the VAE
1037
+ if vae is not None:
1038
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1039
+ update_sd("first_stage_model.", vae_dict)
1040
+
1041
+ # Put together new checkpoint
1042
+ key_count = len(state_dict.keys())
1043
+ new_ckpt = {'state_dict': state_dict}
1044
+
1045
+ if 'epoch' in checkpoint:
1046
+ epochs += checkpoint['epoch']
1047
+ if 'global_step' in checkpoint:
1048
+ steps += checkpoint['global_step']
1049
+
1050
+ new_ckpt['epoch'] = epochs
1051
+ new_ckpt['global_step'] = steps
1052
+
1053
+ if is_safetensors(output_file):
1054
+ # TODO would it be better to drop non-Tensor values from the dict?
1055
+ save_file(state_dict, output_file)
1056
+ else:
1057
+ torch.save(new_ckpt, output_file)
1058
+
1059
+ return key_count
1060
+
1061
+
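Usage sketch (hypothetical output path; text_encoder, unet and vae are assumed to come from load_models_from_stable_diffusion_checkpoint above). With ckpt_path=None the function builds a fresh checkpoint, so a VAE must be passed explicitly:

import torch

key_count = save_stable_diffusion_checkpoint(
    v2=False, output_file="out/model.safetensors",
    text_encoder=text_encoder, unet=unet,
    ckpt_path=None, epochs=1, steps=1000,
    save_dtype=torch.float16, vae=vae)      # key_count is the number of tensors written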
1062
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1063
+ if pretrained_model_name_or_path is None:
1064
+ # load default settings for v1/v2
1065
+ if v2:
1066
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1067
+ else:
1068
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1069
+
1070
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1071
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1072
+ if vae is None:
1073
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1074
+
1075
+ pipeline = StableDiffusionPipeline(
1076
+ unet=unet,
1077
+ text_encoder=text_encoder,
1078
+ vae=vae,
1079
+ scheduler=scheduler,
1080
+ tokenizer=tokenizer,
1081
+ safety_checker=None,
1082
+ feature_extractor=None,
1083
+ requires_safety_checker=None,
1084
+ )
1085
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1086
+
1087
+
1088
+ VAE_PREFIX = "first_stage_model."
1089
+
1090
+
1091
+ def load_vae(vae_id, dtype):
1092
+ print(f"load VAE: {vae_id}")
1093
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1094
+ # Diffusers local/remote
1095
+ try:
1096
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1097
+ except EnvironmentError as e:
1098
+ print(f"exception occurs in loading vae: {e}")
1099
+ print("retry with subfolder='vae'")
1100
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1101
+ return vae
1102
+
1103
+ # local
1104
+ vae_config = create_vae_diffusers_config()
1105
+
1106
+ if vae_id.endswith(".bin"):
1107
+ # SD 1.5 VAE on Huggingface
1108
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1109
+ else:
1110
+ # StableDiffusion
1111
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1112
+ else torch.load(vae_id, map_location="cpu"))
1113
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1114
+
1115
+ # vae only or full model
1116
+ full_model = False
1117
+ for vae_key in vae_sd:
1118
+ if vae_key.startswith(VAE_PREFIX):
1119
+ full_model = True
1120
+ break
1121
+ if not full_model:
1122
+ sd = {}
1123
+ for key, value in vae_sd.items():
1124
+ sd[VAE_PREFIX + key] = value
1125
+ vae_sd = sd
1126
+ del sd
1127
+
1128
+ # Convert the VAE model.
1129
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1130
+
1131
+ vae = AutoencoderKL(**vae_config)
1132
+ vae.load_state_dict(converted_vae_checkpoint)
1133
+ return vae
1134
+
1135
+ # endregion
1136
+
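Usage sketch for load_vae: it accepts a Diffusers folder or repo id, an HF ".bin" VAE, or a full or VAE-only StableDiffusion checkpoint. The repo id and path below are only examples:

import torch

vae = load_vae("stabilityai/sd-vae-ft-mse", dtype=torch.float16)                   # Diffusers repo id (example)
# vae = load_vae("models/vae-ft-mse-840000-ema-pruned.ckpt", dtype=torch.float32)  # SD-format file (hypothetical path)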
1137
+
1138
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1139
+ max_width, max_height = max_reso
1140
+ max_area = (max_width // divisible) * (max_height // divisible)
1141
+
1142
+ resos = set()
1143
+
1144
+ size = int(math.sqrt(max_area)) * divisible
1145
+ resos.add((size, size))
1146
+
1147
+ size = min_size
1148
+ while size <= max_size:
1149
+ width = size
1150
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1151
+ resos.add((width, height))
1152
+ resos.add((height, width))
1153
+
1154
+ # # make additional resos
1155
+ # if width >= height and width - divisible >= min_size:
1156
+ # resos.add((width - divisible, height))
1157
+ # resos.add((height, width - divisible))
1158
+ # if height >= width and height - divisible >= min_size:
1159
+ # resos.add((width, height - divisible))
1160
+ # resos.add((height - divisible, width))
1161
+
1162
+ size += divisible
1163
+
1164
+ resos = list(resos)
1165
+ resos.sort()
1166
+ return resos
1167
+
1168
+
1169
+ if __name__ == '__main__':
1170
+ resos = make_bucket_resolutions((512, 768))
1171
+ print(len(resos))
1172
+ print(resos)
1173
+ aspect_ratios = [w / h for w, h in resos]
1174
+ print(aspect_ratios)
1175
+
1176
+ ars = set()
1177
+ for ar in aspect_ratios:
1178
+ if ar in ars:
1179
+ print("error! duplicate ar:", ar)
1180
+ ars.add(ar)
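Putting this module together, a minimal round-trip sketch (hypothetical paths; assumes the file is importable as library.model_util):

import library.model_util as model_util

text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="models/v1-5-pruned.safetensors")        # hypothetical path
model_util.save_stable_diffusion_checkpoint(
    False, "models/roundtrip.safetensors", text_encoder, unet,
    ckpt_path="models/v1-5-pruned.safetensors", epochs=0, steps=0, vae=vae)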
library/train_util.py ADDED
@@ -0,0 +1,1796 @@
1
+ # common functions for training
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ import time
7
+ from typing import Dict, List, NamedTuple, Tuple
8
+ from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
+ import glob
11
+ import math
12
+ import os
13
+ import random
14
+ import hashlib
15
+ import subprocess
16
+ from io import BytesIO
17
+
18
+ from tqdm import tqdm
19
+ import torch
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer
22
+ import diffusers
23
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
24
+ import albumentations as albu
25
+ import numpy as np
26
+ from PIL import Image
27
+ import cv2
28
+ from einops import rearrange
29
+ from torch import einsum
30
+ import safetensors.torch
31
+
32
+ import library.model_util as model_util
33
+
34
+ # Tokenizer: use the published tokenizer instead of loading it from the checkpoint
35
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
36
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
37
+
38
+ # checkpoint file names
39
+ EPOCH_STATE_NAME = "{}-{:06d}-state"
40
+ EPOCH_FILE_NAME = "{}-{:06d}"
41
+ EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
42
+ LAST_STATE_NAME = "{}-state"
43
+ DEFAULT_EPOCH_NAME = "epoch"
44
+ DEFAULT_LAST_OUTPUT_NAME = "last"
45
+
46
+ # region dataset
47
+
48
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
49
+ # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
50
+
51
+
52
+ class ImageInfo():
53
+ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
54
+ self.image_key: str = image_key
55
+ self.num_repeats: int = num_repeats
56
+ self.caption: str = caption
57
+ self.is_reg: bool = is_reg
58
+ self.absolute_path: str = absolute_path
59
+ self.image_size: Tuple[int, int] = None
60
+ self.resized_size: Tuple[int, int] = None
61
+ self.bucket_reso: Tuple[int, int] = None
62
+ self.latents: torch.Tensor = None
63
+ self.latents_flipped: torch.Tensor = None
64
+ self.latents_npz: str = None
65
+ self.latents_npz_flipped: str = None
66
+
67
+
68
+ class BucketManager():
69
+ def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
70
+ self.no_upscale = no_upscale
71
+ if max_reso is None:
72
+ self.max_reso = None
73
+ self.max_area = None
74
+ else:
75
+ self.max_reso = max_reso
76
+ self.max_area = max_reso[0] * max_reso[1]
77
+ self.min_size = min_size
78
+ self.max_size = max_size
79
+ self.reso_steps = reso_steps
80
+
81
+ self.resos = []
82
+ self.reso_to_id = {}
83
+ self.buckets = [] # holds (image_key, image) during preprocessing, image_key during training
84
+
85
+ def add_image(self, reso, image):
86
+ bucket_id = self.reso_to_id[reso]
87
+ self.buckets[bucket_id].append(image)
88
+
89
+ def shuffle(self):
90
+ for bucket in self.buckets:
91
+ random.shuffle(bucket)
92
+
93
+ def sort(self):
94
+ # sort by resolution (purely to make display and metadata storage look nicer); reorder buckets and rebuild reso_to_id as well
95
+ sorted_resos = self.resos.copy()
96
+ sorted_resos.sort()
97
+
98
+ sorted_buckets = []
99
+ sorted_reso_to_id = {}
100
+ for i, reso in enumerate(sorted_resos):
101
+ bucket_id = self.reso_to_id[reso]
102
+ sorted_buckets.append(self.buckets[bucket_id])
103
+ sorted_reso_to_id[reso] = i
104
+
105
+ self.resos = sorted_resos
106
+ self.buckets = sorted_buckets
107
+ self.reso_to_id = sorted_reso_to_id
108
+
109
+ def make_buckets(self):
110
+ resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
111
+ self.set_predefined_resos(resos)
112
+
113
+ def set_predefined_resos(self, resos):
114
+ # store the resolutions and aspect ratios used when selecting from the predefined sizes
115
+ self.predefined_resos = resos.copy()
116
+ self.predefined_resos_set = set(resos)
117
+ self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
118
+
119
+ def add_if_new_reso(self, reso):
120
+ if reso not in self.reso_to_id:
121
+ bucket_id = len(self.resos)
122
+ self.reso_to_id[reso] = bucket_id
123
+ self.resos.append(reso)
124
+ self.buckets.append([])
125
+ # print(reso, bucket_id, len(self.buckets))
126
+
127
+ def round_to_steps(self, x):
128
+ x = int(x + .5)
129
+ return x - x % self.reso_steps
130
+
131
+ def select_bucket(self, image_width, image_height):
132
+ aspect_ratio = image_width / image_height
133
+ if not self.no_upscale:
134
+ # the same aspect ratio may already exist (fine tuning data preprocessed with no_upscale=True), so prefer a bucket with the identical resolution
135
+ reso = (image_width, image_height)
136
+ if reso in self.predefined_resos_set:
137
+ pass
138
+ else:
139
+ ar_errors = self.predefined_aspect_ratios - aspect_ratio
140
+ predefined_bucket_id = np.abs(ar_errors).argmin() # the resolution whose aspect ratio error is smallest
141
+ reso = self.predefined_resos[predefined_bucket_id]
142
+
143
+ ar_reso = reso[0] / reso[1]
144
+ if aspect_ratio > ar_reso: # the image is wider than the bucket -> fit the height
145
+ scale = reso[1] / image_height
146
+ else:
147
+ scale = reso[0] / image_width
148
+
149
+ resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
150
+ # print("use predef", image_width, image_height, reso, resized_size)
151
+ else:
152
+ if image_width * image_height > self.max_area:
153
+ # the image is too large, so choose the bucket assuming it will be shrunk while keeping its aspect ratio
154
+ resized_width = math.sqrt(self.max_area * aspect_ratio)
155
+ resized_height = self.max_area / resized_width
156
+ assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
157
+
158
+ # round the resized short or long edge to a multiple of reso_steps, choosing whichever deviates less from the aspect ratio
159
+ # same logic as the original bucketing
160
+ b_width_rounded = self.round_to_steps(resized_width)
161
+ b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
162
+ ar_width_rounded = b_width_rounded / b_height_in_wr
163
+
164
+ b_height_rounded = self.round_to_steps(resized_height)
165
+ b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
166
+ ar_height_rounded = b_width_in_hr / b_height_rounded
167
+
168
+ # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
169
+ # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
170
+
171
+ if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
172
+ resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
173
+ else:
174
+ resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
175
+ # print(resized_size)
176
+ else:
177
+ resized_size = (image_width, image_height) # no resize needed
178
+
179
+ # make the bucket no larger than the image (crop instead of padding)
180
+ bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
181
+ bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
182
+ # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
183
+
184
+ reso = (bucket_width, bucket_height)
185
+
186
+ self.add_if_new_reso(reso)
187
+
188
+ ar_error = (reso[0] / reso[1]) - aspect_ratio
189
+ return reso, resized_size, ar_error
190
+
191
+
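A minimal usage sketch of BucketManager (illustrative values; assumes the class is importable from library.train_util, and make_buckets relies on library.model_util.make_bucket_resolutions):

bm = BucketManager(no_upscale=False, max_reso=(768, 512), min_size=256, max_size=1024, reso_steps=64)
bm.make_buckets()                                            # build the predefined resolutions
reso, resized_size, ar_error = bm.select_bucket(1000, 600)   # picks the closest predefined aspect ratio
bm.add_image(reso, "image_key_0001")                         # hypothetical image key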
192
+ class BucketBatchIndex(NamedTuple):
193
+ bucket_index: int
194
+ bucket_batch_size: int
195
+ batch_index: int
196
+
197
+
198
+ class BaseDataset(torch.utils.data.Dataset):
199
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
+ super().__init__()
201
+ self.tokenizer: CLIPTokenizer = tokenizer
202
+ self.max_token_length = max_token_length
203
+ self.shuffle_caption = shuffle_caption
204
+ self.shuffle_keep_tokens = shuffle_keep_tokens
205
+ # width/height is used when enable_bucket==False
206
+ self.width, self.height = (None, None) if resolution is None else resolution
207
+ self.face_crop_aug_range = face_crop_aug_range
208
+ self.flip_aug = flip_aug
209
+ self.color_aug = color_aug
210
+ self.debug_dataset = debug_dataset
211
+ self.random_crop = random_crop
212
+ self.token_padding_disabled = False
213
+ self.dataset_dirs_info = {}
214
+ self.reg_dataset_dirs_info = {}
215
+ self.tag_frequency = {}
216
+
217
+ self.enable_bucket = False
218
+ self.bucket_manager: BucketManager = None # not initialized
219
+ self.min_bucket_reso = None
220
+ self.max_bucket_reso = None
221
+ self.bucket_reso_steps = None
222
+ self.bucket_no_upscale = None
223
+ self.bucket_info = None # for metadata
224
+
225
+ self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
+
227
+ self.current_epoch: int = 0 # instances seem to be recreated every epoch, so this has to be set from the outside
228
+ self.dropout_rate: float = 0
229
+ self.dropout_every_n_epochs: int = None
230
+ self.tag_dropout_rate: float = 0
231
+
232
+ # augmentation
233
+ flip_p = 0.5 if flip_aug else 0.0
234
+ if color_aug:
235
+ # fairly mild color augmentation: brightness/contrast would change the image's min/max pixel values, which is assumed to be undesirable, so only gamma/hue are adjusted
236
+ self.aug = albu.Compose([
237
+ albu.OneOf([
238
+ albu.HueSaturationValue(8, 0, 0, p=.5),
239
+ albu.RandomGamma((95, 105), p=.5),
240
+ ], p=.33),
241
+ albu.HorizontalFlip(p=flip_p)
242
+ ], p=1.)
243
+ elif flip_aug:
244
+ self.aug = albu.Compose([
245
+ albu.HorizontalFlip(p=flip_p)
246
+ ], p=1.)
247
+ else:
248
+ self.aug = None
249
+
250
+ self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
+
252
+ self.image_data: Dict[str, ImageInfo] = {}
253
+
254
+ self.replacements = {}
255
+
256
+ def set_current_epoch(self, epoch):
257
+ self.current_epoch = epoch
258
+
259
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
+ # not passed via the constructor so that Textual Inversion does not have to be aware of it (at least that is the rationale)
261
+ self.dropout_rate = dropout_rate
262
+ self.dropout_every_n_epochs = dropout_every_n_epochs
263
+ self.tag_dropout_rate = tag_dropout_rate
264
+
265
+ def set_tag_frequency(self, dir_name, captions):
266
+ frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
+ self.tag_frequency[dir_name] = frequency_for_dir
268
+ for caption in captions:
269
+ for tag in caption.split(","):
270
+ if tag and not tag.isspace():
271
+ tag = tag.lower()
272
+ frequency = frequency_for_dir.get(tag, 0)
273
+ frequency_for_dir[tag] = frequency + 1
274
+
275
+ def disable_token_padding(self):
276
+ self.token_padding_disabled = True
277
+
278
+ def add_replacement(self, str_from, str_to):
279
+ self.replacements[str_from] = str_to
280
+
281
+ def process_caption(self, caption):
282
+ # decide on dropout here: tag dropout also happens inside this method, so this is the natural place
283
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
+
286
+ if is_drop_out:
287
+ caption = ""
288
+ else:
289
+ if self.shuffle_caption or self.tag_dropout_rate > 0:
290
+ def dropout_tags(tokens):
291
+ if self.tag_dropout_rate <= 0:
292
+ return tokens
293
+ l = []
294
+ for token in tokens:
295
+ if random.random() >= self.tag_dropout_rate:
296
+ l.append(token)
297
+ return l
298
+
299
+ tokens = [t.strip() for t in caption.strip().split(",")]
300
+ if self.shuffle_keep_tokens is None:
301
+ if self.shuffle_caption:
302
+ random.shuffle(tokens)
303
+
304
+ tokens = dropout_tags(tokens)
305
+ else:
306
+ if len(tokens) > self.shuffle_keep_tokens:
307
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
308
+ tokens = tokens[self.shuffle_keep_tokens:]
309
+
310
+ if self.shuffle_caption:
311
+ random.shuffle(tokens)
312
+
313
+ tokens = dropout_tags(tokens)
314
+
315
+ tokens = keep_tokens + tokens
316
+ caption = ", ".join(tokens)
317
+
318
+ # textual inversion support
319
+ for str_from, str_to in self.replacements.items():
320
+ if str_from == "":
321
+ # replace all
322
+ if type(str_to) == list:
323
+ caption = random.choice(str_to)
324
+ else:
325
+ caption = str_to
326
+ else:
327
+ caption = caption.replace(str_from, str_to)
328
+
329
+ return caption
330
+
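The comma-separated tag handling above boils down to the following illustrative sketch (shuffle, optional per-tag dropout, rejoin):

import random

tokens = [t.strip() for t in "1girl, solo, smile, outdoors".split(",")]
random.shuffle(tokens)                                   # shuffle_caption
tokens = [t for t in tokens if random.random() >= 0.1]   # tag_dropout_rate = 0.1 (example value)
caption = ", ".join(tokens)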
331
+ def get_input_ids(self, caption):
332
+ input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
333
+ max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
334
+
335
+ if self.tokenizer_max_length > self.tokenizer.model_max_length:
336
+ input_ids = input_ids.squeeze(0)
337
+ iids_list = []
338
+ if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
339
+ # v1
340
+ # for more than 77 tokens the ids are "<BOS> .... <EOS> <EOS> <EOS>" totalling e.g. 227, so convert them into triplets of "<BOS>...<EOS>"
341
+ # the AUTOMATIC1111 implementation seems to split on "," etc., but keep it simple here for now
342
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
343
+ ids_chunk = (input_ids[0].unsqueeze(0),
344
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
345
+ input_ids[-1].unsqueeze(0))
346
+ ids_chunk = torch.cat(ids_chunk)
347
+ iids_list.append(ids_chunk)
348
+ else:
349
+ # v2
350
+ # for more than 77 tokens the ids are "<BOS> .... <EOS> <PAD> <PAD>..." totalling e.g. 227, so convert them into triplets of "<BOS>...<EOS> <PAD> <PAD> ..."
351
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
352
+ ids_chunk = (input_ids[0].unsqueeze(0), # BOS
353
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
354
+ input_ids[-1].unsqueeze(0)) # PAD or EOS
355
+ ids_chunk = torch.cat(ids_chunk)
356
+
357
+ # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
358
+ # if it ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it was already x <EOS>)
359
+ if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
360
+ ids_chunk[-1] = self.tokenizer.eos_token_id
361
+ # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
362
+ if ids_chunk[1] == self.tokenizer.pad_token_id:
363
+ ids_chunk[1] = self.tokenizer.eos_token_id
364
+
365
+ iids_list.append(ids_chunk)
366
+
367
+ input_ids = torch.stack(iids_list) # 3,77
368
+ return input_ids
369
+
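The chunking arithmetic above, shown for max_token_length=225 (so tokenizer_max_length=227) and model_max_length=77; each chunk becomes <BOS> + 75 tokens + <EOS>/<PAD>:

tokenizer_max_length, model_max_length = 227, 77
starts = list(range(1, tokenizer_max_length - model_max_length + 2, model_max_length - 2))
print(starts)   # [1, 76, 151] -> three chunks, stacked into a (3, 77) tensor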
370
+ def register_image(self, info: ImageInfo):
371
+ self.image_data[info.image_key] = info
372
+
373
+ def make_buckets(self):
374
+ '''
375
+ must be called even when bucketing is disabled (a single bucket is created)
376
+ min_size and max_size are ignored when enable_bucket is False
377
+ '''
378
+ print("loading image sizes.")
379
+ for info in tqdm(self.image_data.values()):
380
+ if info.image_size is None:
381
+ info.image_size = self.get_image_size(info.absolute_path)
382
+
383
+ if self.enable_bucket:
384
+ print("make buckets")
385
+ else:
386
+ print("prepare dataset")
387
+
388
+ # create buckets and assign images to them
389
+ if self.enable_bucket:
390
+ if self.bucket_manager is None: # when fine tuning with bucket info in the metadata, it is already initialized
391
+ self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
392
+ self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
393
+ if not self.bucket_no_upscale:
394
+ self.bucket_manager.make_buckets()
395
+ else:
396
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
397
+
398
+ img_ar_errors = []
399
+ for image_info in self.image_data.values():
400
+ image_width, image_height = image_info.image_size
401
+ image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
402
+
403
+ # print(image_info.image_key, image_info.bucket_reso)
404
+ img_ar_errors.append(abs(ar_error))
405
+
406
+ self.bucket_manager.sort()
407
+ else:
408
+ self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
409
+ self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # only a single fixed-size bucket
410
+ for image_info in self.image_data.values():
411
+ image_width, image_height = image_info.image_size
412
+ image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
413
+
414
+ for image_info in self.image_data.values():
415
+ for _ in range(image_info.num_repeats):
416
+ self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
417
+
418
+ # print and store bucket info
419
+ if self.enable_bucket:
420
+ self.bucket_info = {"buckets": {}}
421
+ print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
422
+ for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
423
+ count = len(bucket)
424
+ if count > 0:
425
+ self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
426
+ print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
427
+
428
+ img_ar_errors = np.array(img_ar_errors)
429
+ mean_img_ar_error = np.mean(np.abs(img_ar_errors))
430
+ self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
431
+ print(f"mean ar error (without repeats): {mean_img_ar_error}")
432
+
433
+ # build the index used to look up data; this index is used when shuffling the dataset
434
+ self.buckets_indices: List[BucketBatchIndex] = []
435
+ for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
436
+ batch_count = int(math.ceil(len(bucket) / self.batch_size))
437
+ for batch_index in range(batch_count):
438
+ self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
439
+
440
+ # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
441
+ #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
442
+ #
443
+ # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
444
+ # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
445
+ # # そのためバッチサイズを画像種類までに制限する
446
+ # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
447
+ # # TO DO 正則化画像をepochまたがりで利用する仕組み
448
+ # num_of_image_types = len(set(bucket))
449
+ # bucket_batch_size = min(self.batch_size, num_of_image_types)
450
+ # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
451
+ # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
452
+ # for batch_index in range(batch_count):
453
+ # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
454
+ # ↑ここまで
455
+
456
+ self.shuffle_buckets()
457
+ self._length = len(self.buckets_indices)
458
+
459
+ def shuffle_buckets(self):
460
+ random.shuffle(self.buckets_indices)
461
+ self.bucket_manager.shuffle()
462
+
463
+ def load_image(self, image_path):
464
+ image = Image.open(image_path)
465
+ if not image.mode == "RGB":
466
+ image = image.convert("RGB")
467
+ img = np.array(image, np.uint8)
468
+ return img
469
+
470
+ def trim_and_resize_if_required(self, image, reso, resized_size):
471
+ image_height, image_width = image.shape[0:2]
472
+
473
+ if image_width != resized_size[0] or image_height != resized_size[1]:
474
+ # resize
475
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # use cv2 here because we want INTER_AREA
476
+
477
+ image_height, image_width = image.shape[0:2]
478
+ if image_width > reso[0]:
479
+ trim_size = image_width - reso[0]
480
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
+ # print("w", trim_size, p)
482
+ image = image[:, p:p + reso[0]]
483
+ if image_height > reso[1]:
484
+ trim_size = image_height - reso[1]
485
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
+ # print("h", trim_size, p)
487
+ image = image[p:p + reso[1]]
488
+
489
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
+ return image
491
+
492
+ def cache_latents(self, vae):
493
+ # TODO speed this up
494
+ print("caching latents.")
495
+ for info in tqdm(self.image_data.values()):
496
+ if info.latents_npz is not None:
497
+ info.latents = self.load_latents_from_npz(info, False)
498
+ info.latents = torch.FloatTensor(info.latents)
499
+ info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
500
+ if info.latents_flipped is not None:
501
+ info.latents_flipped = torch.FloatTensor(info.latents_flipped)
502
+ continue
503
+
504
+ image = self.load_image(info.absolute_path)
505
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
+
507
+ img_tensor = self.image_transforms(image)
508
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
+ info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
+
511
+ if self.flip_aug:
512
+ image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
+ img_tensor = self.image_transforms(image)
514
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
515
+ info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
516
+
517
+ def get_image_size(self, image_path):
518
+ image = Image.open(image_path)
519
+ return image.size
520
+
521
+ def load_image_with_face_info(self, image_path: str):
522
+ img = self.load_image(image_path)
523
+
524
+ face_cx = face_cy = face_w = face_h = 0
525
+ if self.face_crop_aug_range is not None:
526
+ tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
+ if len(tokens) >= 5:
528
+ face_cx = int(tokens[-4])
529
+ face_cy = int(tokens[-3])
530
+ face_w = int(tokens[-2])
531
+ face_h = int(tokens[-1])
532
+
533
+ return img, face_cx, face_cy, face_w, face_h
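An illustrative example of the file-name convention assumed above (the suffix carries the face center and size; the file name is hypothetical):

import os

path = "train/person_0123_0256_0128_0128.png"
tokens = os.path.splitext(os.path.basename(path))[0].split('_')
face_cx, face_cy, face_w, face_h = (int(t) for t in tokens[-4:])   # 123, 256, 128, 128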
534
+
535
+ # crop nicely around the face
536
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
+ height, width = image.shape[0:2]
538
+ if height == self.height and width == self.width:
539
+ return image
540
+
541
+ # the image is larger than the target size, so resize it
542
+ face_size = max(face_w, face_h)
543
+ min_scale = max(self.height / height, self.width / width) # the smallest scale at which the image exactly covers the model input size
544
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # scale for the specified minimum face size
545
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # scale for the specified maximum face size
546
+ if min_scale >= max_scale: # the specified range has min == max
547
+ scale = min_scale
548
+ else:
549
+ scale = random.uniform(min_scale, max_scale)
550
+
551
+ nh = int(height * scale + .5)
552
+ nw = int(width * scale + .5)
553
+ assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
554
+ image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
555
+ face_cx = int(face_cx * scale + .5)
556
+ face_cy = int(face_cy * scale + .5)
557
+ height, width = nh, nw
558
+
559
+ # crop to e.g. 448*640 with the face at the center
560
+ for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
+ p1 = face_p - target_size // 2 # crop position that brings the face to the center
562
+
563
+ if self.random_crop:
564
+ # shift the crop while biasing toward keeping the face centered, so that some background is also included
565
+ range = max(length - face_p, face_p) # the longer of the distances from the image edges to the face center
566
+ p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a reasonable random offset within -range ... +range
567
+ else:
568
+ # only when a range is specified, shift a little at random (fairly ad hoc)
569
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
+ if face_size > self.size // 10 and face_size >= 40:
571
+ p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
+
573
+ p1 = max(0, min(p1, length - target_size))
574
+
575
+ if axis == 0:
576
+ image = image[p1:p1 + target_size, :]
577
+ else:
578
+ image = image[:, p1:p1 + target_size]
579
+
580
+ return image
581
+
582
+ def load_latents_from_npz(self, image_info: ImageInfo, flipped):
583
+ npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
584
+ if npz_file is None:
585
+ return None
586
+ return np.load(npz_file)['arr_0']
587
+
588
+ def __len__(self):
589
+ return self._length
590
+
591
+ def __getitem__(self, index):
592
+ if index == 0:
593
+ self.shuffle_buckets()
594
+
595
+ bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
+ bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
+ image_index = self.buckets_indices[index].batch_index * bucket_batch_size
598
+
599
+ loss_weights = []
600
+ captions = []
601
+ input_ids_list = []
602
+ latents_list = []
603
+ images = []
604
+
605
+ for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
+ image_info = self.image_data[image_key]
607
+ loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
+
609
+ # process the image/latents
610
+ if image_info.latents is not None:
611
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
+ image = None
613
+ elif image_info.latents_npz is not None:
614
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
+ latents = torch.FloatTensor(latents)
616
+ image = None
617
+ else:
618
+ # load the image and crop it if necessary
619
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
+ im_h, im_w = img.shape[0:2]
621
+
622
+ if self.enable_bucket:
623
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
+ else:
625
+ if face_cx > 0: # face position info is available
626
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
+ elif im_h > self.height or im_w > self.width:
628
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
+ if im_h > self.height:
630
+ p = random.randint(0, im_h - self.height)
631
+ img = img[p:p + self.height]
632
+ if im_w > self.width:
633
+ p = random.randint(0, im_w - self.width)
634
+ img = img[:, p:p + self.width]
635
+
636
+ im_h, im_w = img.shape[0:2]
637
+ assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
+
639
+ # augmentation
640
+ if self.aug is not None:
641
+ img = self.aug(image=img)['image']
642
+
643
+ latents = None
644
+ image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0
645
+
646
+ images.append(image)
647
+ latents_list.append(latents)
648
+
649
+ caption = self.process_caption(image_info.caption)
650
+ captions.append(caption)
651
+ if not self.token_padding_disabled: # this option might be omitted in future
652
+ input_ids_list.append(self.get_input_ids(caption))
653
+
654
+ example = {}
655
+ example['loss_weights'] = torch.FloatTensor(loss_weights)
656
+
657
+ if self.token_padding_disabled:
658
+ # padding=True means pad in the batch
659
+ example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
660
+ else:
661
+ # batch processing seems to be good
662
+ example['input_ids'] = torch.stack(input_ids_list)
663
+
664
+ if images[0] is not None:
665
+ images = torch.stack(images)
666
+ images = images.to(memory_format=torch.contiguous_format).float()
667
+ else:
668
+ images = None
669
+ example['images'] = images
670
+
671
+ example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
672
+
673
+ if self.debug_dataset:
674
+ example['image_keys'] = bucket[image_index:image_index + self.batch_size]
675
+ example['captions'] = captions
676
+ return example
677
+
678
+
679
+ class DreamBoothDataset(BaseDataset):
680
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
+
684
+ assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
+
686
+ self.batch_size = batch_size
687
+ self.size = min(self.width, self.height) # the shorter side
688
+ self.prior_loss_weight = prior_loss_weight
689
+ self.latents_cache = None
690
+
691
+ self.enable_bucket = enable_bucket
692
+ if self.enable_bucket:
693
+ assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
694
+ assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
695
+ self.min_bucket_reso = min_bucket_reso
696
+ self.max_bucket_reso = max_bucket_reso
697
+ self.bucket_reso_steps = bucket_reso_steps
698
+ self.bucket_no_upscale = bucket_no_upscale
699
+ else:
700
+ self.min_bucket_reso = None
701
+ self.max_bucket_reso = None
702
+ self.bucket_reso_steps = None # this value is not used
703
+ self.bucket_no_upscale = False
704
+
705
+ def read_caption(img_path):
706
+ # build candidate caption file names
707
+ base_name = os.path.splitext(img_path)[0]
708
+ base_name_face_det = base_name
709
+ tokens = base_name.split("_")
710
+ if len(tokens) >= 5:
711
+ base_name_face_det = "_".join(tokens[:-4])
712
+ cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
713
+
714
+ caption = None
715
+ for cap_path in cap_paths:
716
+ if os.path.isfile(cap_path):
717
+ with open(cap_path, "rt", encoding='utf-8') as f:
718
+ try:
719
+ lines = f.readlines()
720
+ except UnicodeDecodeError as e:
721
+ print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
722
+ raise e
723
+ assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
724
+ caption = lines[0].strip()
725
+ break
726
+ return caption
727
+
728
+ def load_dreambooth_dir(dir):
729
+ if not os.path.isdir(dir):
730
+ # print(f"ignore file: {dir}")
731
+ return 0, [], []
732
+
733
+ tokens = os.path.basename(dir).split('_')
734
+ try:
735
+ n_repeats = int(tokens[0])
736
+ except ValueError as e:
737
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
+ return 0, [], []
739
+
740
+ caption_by_folder = '_'.join(tokens[1:])
741
+ img_paths = glob_images(dir, "*")
742
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
+
744
+ # read a caption for each image file and use it if present
745
+ captions = []
746
+ for img_path in img_paths:
747
+ cap_for_img = read_caption(img_path)
748
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
749
+
750
+ self.set_tag_frequency(os.path.basename(dir), captions) # record tag frequencies
751
+
752
+ return n_repeats, img_paths, captions
753
+
754
+ print("prepare train images.")
755
+ train_dirs = os.listdir(train_data_dir)
756
+ num_train_images = 0
757
+ for dir in train_dirs:
758
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
+ num_train_images += n_repeats * len(img_paths)
760
+
761
+ for img_path, caption in zip(img_paths, captions):
762
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
+ self.register_image(info)
764
+
765
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
766
+
767
+ print(f"{num_train_images} train images with repeating.")
768
+ self.num_train_images = num_train_images
769
+
770
+ # count reg images and repeat them so they match the number of training images
771
+ num_reg_images = 0
772
+ if reg_data_dir:
773
+ print("prepare reg images.")
774
+ reg_infos: List[ImageInfo] = []
775
+
776
+ reg_dirs = os.listdir(reg_data_dir)
777
+ for dir in reg_dirs:
778
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
+ num_reg_images += n_repeats * len(img_paths)
780
+
781
+ for img_path, caption in zip(img_paths, captions):
782
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
+ reg_infos.append(info)
784
+
785
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
+
787
+ print(f"{num_reg_images} reg images.")
788
+ if num_train_images < num_reg_images:
789
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
+
791
+ if num_reg_images == 0:
792
+ print("no regularization images / 正則化画像が見つかりませんでした")
793
+ else:
794
+ # compute num_repeats; the numbers are small anyway, so a simple loop is fine
795
+ n = 0
796
+ first_loop = True
797
+ while n < num_train_images:
798
+ for info in reg_infos:
799
+ if first_loop:
800
+ self.register_image(info)
801
+ n += info.num_repeats
802
+ else:
803
+ info.num_repeats += 1
804
+ n += 1
805
+ if n >= num_train_images:
806
+ break
807
+ first_loop = False
808
+
809
+ self.num_reg_images = num_reg_images
810
+
811
+
812
+ class FineTuningDataset(BaseDataset):
813
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
+
817
+ # load the metadata
818
+ if os.path.exists(json_file_name):
819
+ print(f"loading existing metadata: {json_file_name}")
820
+ with open(json_file_name, "rt", encoding='utf-8') as f:
821
+ metadata = json.load(f)
822
+ else:
823
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
+
825
+ self.metadata = metadata
826
+ self.train_data_dir = train_data_dir
827
+ self.batch_size = batch_size
828
+
829
+ tags_list = []
830
+ for image_key, img_md in metadata.items():
831
+ # build the path info
832
+ if os.path.exists(image_key):
833
+ abs_path = image_key
834
+ else:
835
+ # fairly sloppy, but no better idea comes to mind
836
+ abs_path = glob_images(train_data_dir, image_key)
837
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
+ abs_path = abs_path[0]
839
+
840
+ caption = img_md.get('caption')
841
+ tags = img_md.get('tags')
842
+ if caption is None:
843
+ caption = tags
844
+ elif tags is not None and len(tags) > 0:
845
+ caption = caption + ', ' + tags
846
+ tags_list.append(tags)
847
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
+
849
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
+ image_info.image_size = img_md.get('train_resolution')
851
+
852
+ if not self.color_aug and not self.random_crop:
853
+ # if npz exists, use them
854
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
+
856
+ self.register_image(image_info)
857
+ self.num_train_images = len(metadata) * dataset_repeats
858
+ self.num_reg_images = 0
859
+
860
+ # TODO do not record tag freq when no tag
861
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
863
+
864
+ # check existence of all npz files
865
+ use_npz_latents = not (self.color_aug or self.random_crop)
866
+ if use_npz_latents:
867
+ npz_any = False
868
+ npz_all = True
869
+ for image_info in self.image_data.values():
870
+ has_npz = image_info.latents_npz is not None
871
+ npz_any = npz_any or has_npz
872
+
873
+ if self.flip_aug:
874
+ has_npz = has_npz and image_info.latents_npz_flipped is not None
875
+ npz_all = npz_all and has_npz
876
+
877
+ if npz_any and not npz_all:
878
+ break
879
+
880
+ if not npz_any:
881
+ use_npz_latents = False
882
+ print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
883
+ elif not npz_all:
884
+ use_npz_latents = False
885
+ print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
886
+ if self.flip_aug:
887
+ print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
+ # else:
889
+ # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
890
+
891
+ # check min/max bucket size
892
+ sizes = set()
893
+ resos = set()
894
+ for image_info in self.image_data.values():
895
+ if image_info.image_size is None:
896
+ sizes = None # not calculated
897
+ break
898
+ sizes.add(image_info.image_size[0])
899
+ sizes.add(image_info.image_size[1])
900
+ resos.add(tuple(image_info.image_size))
901
+
902
+ if sizes is None:
903
+ if use_npz_latents:
904
+ use_npz_latents = False
905
+ print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
906
+
907
+ assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
908
+
909
+ self.enable_bucket = enable_bucket
910
+ if self.enable_bucket:
911
+ self.min_bucket_reso = min_bucket_reso
912
+ self.max_bucket_reso = max_bucket_reso
913
+ self.bucket_reso_steps = bucket_reso_steps
914
+ self.bucket_no_upscale = bucket_no_upscale
915
+ else:
916
+ if not enable_bucket:
917
+ print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
918
+ print("using bucket info in metadata / メタデータ内のbucket情報を使います")
919
+ self.enable_bucket = True
920
+
921
+ assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
922
+
923
+ # initialize the bucket info here; it is not recreated in make_buckets
924
+ self.bucket_manager = BucketManager(False, None, None, None, None)
925
+ self.bucket_manager.set_predefined_resos(resos)
926
+
927
+ # clean up the npz info
928
+ if not use_npz_latents:
929
+ for image_info in self.image_data.values():
930
+ image_info.latents_npz = image_info.latents_npz_flipped = None
931
+
932
+ def image_key_to_npz_file(self, image_key):
933
+ base_name = os.path.splitext(image_key)[0]
934
+ npz_file_norm = base_name + '.npz'
935
+
936
+ if os.path.exists(npz_file_norm):
937
+ # image_key is full path
938
+ npz_file_flip = base_name + '_flip.npz'
939
+ if not os.path.exists(npz_file_flip):
940
+ npz_file_flip = None
941
+ return npz_file_norm, npz_file_flip
942
+
943
+ # image_key is relative path
944
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
+
947
+ if not os.path.exists(npz_file_norm):
948
+ npz_file_norm = None
949
+ npz_file_flip = None
950
+ elif not os.path.exists(npz_file_flip):
951
+ npz_file_flip = None
952
+
953
+ return npz_file_norm, npz_file_flip
954
+
955
+
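For reference, a minimal sketch of the metadata JSON this dataset expects, inferred from the fields read in __init__ above ("caption", "tags", and the optional "train_resolution" bucket info); the file and image names are hypothetical. Cached latents, if present, live next to the images as "<image_key>.npz" and "<image_key>_flip.npz", as resolved by image_key_to_npz_file.

import json

# Hypothetical metadata: one entry per image key.
metadata = {
    "img_0001": {
        "caption": "a photo of a cat",
        "tags": "cat, outdoors",
        "train_resolution": [512, 640],  # optional; enables precomputed buckets
    }
}
with open("meta_lat.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False, indent=2)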
956
+ def debug_dataset(train_dataset, show_input_ids=False):
957
+ print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
+ print("Escape for exit. / Escキーで中断、終了します")
959
+
960
+ train_dataset.set_current_epoch(1)
961
+ k = 0
962
+ for i, example in enumerate(train_dataset):
963
+ if example['latents'] is not None:
964
+ print(f"sample has latents from npz file: {example['latents'].size()}")
965
+ for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
966
+ print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
967
+ if show_input_ids:
968
+ print(f"input ids: {iid}")
969
+ if example['images'] is not None:
970
+ im = example['images'][j]
971
+ print(f"image size: {im.size()}")
972
+ im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
973
+ im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
974
+ im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
975
+ if os.name == 'nt': # only windows
976
+ cv2.imshow("img", im)
977
+ k = cv2.waitKey()
978
+ cv2.destroyAllWindows()
979
+ if k == 27:
980
+ break
981
+ if k == 27 or (example['images'] is None and i >= 8):
982
+ break
983
+
984
+
985
+ def glob_images(directory, base="*"):
986
+ img_paths = []
987
+ for ext in IMAGE_EXTENSIONS:
988
+ if base == '*':
989
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
990
+ else:
991
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
992
+ # img_paths = list(set(img_paths)) # 重複を排除
993
+ # img_paths.sort()
994
+ return img_paths
995
+
996
+
997
+ def glob_images_pathlib(dir_path, recursive):
998
+ image_paths = []
999
+ if recursive:
1000
+ for ext in IMAGE_EXTENSIONS:
1001
+ image_paths += list(dir_path.rglob('*' + ext))
1002
+ else:
1003
+ for ext in IMAGE_EXTENSIONS:
1004
+ image_paths += list(dir_path.glob('*' + ext))
1005
+ # image_paths = list(set(image_paths)) # 重複を排除
1006
+ # image_paths.sort()
1007
+ return image_paths
1008
+
1009
+ # endregion
1010
+
1011
+
1012
+ # region module replacement
1013
+ """
1014
+ Module replacement for speed-ups
1015
+ """
1016
+
1017
+ # CrossAttention that uses FlashAttention
1018
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
1019
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
1020
+
1021
+ # constants
1022
+
1023
+ EPSILON = 1e-6
1024
+
1025
+ # helper functions
1026
+
1027
+
1028
+ def exists(val):
1029
+ return val is not None
1030
+
1031
+
1032
+ def default(val, d):
1033
+ return val if exists(val) else d
1034
+
1035
+
1036
+ def model_hash(filename):
1037
+ """Old model hash used by stable-diffusion-webui"""
1038
+ try:
1039
+ with open(filename, "rb") as file:
1040
+ m = hashlib.sha256()
1041
+
1042
+ file.seek(0x100000)
1043
+ m.update(file.read(0x10000))
1044
+ return m.hexdigest()[0:8]
1045
+ except FileNotFoundError:
1046
+ return 'NOFILE'
1047
+
1048
+
1049
+ def calculate_sha256(filename):
1050
+ """New model hash used by stable-diffusion-webui"""
1051
+ hash_sha256 = hashlib.sha256()
1052
+ blksize = 1024 * 1024
1053
+
1054
+ with open(filename, "rb") as f:
1055
+ for chunk in iter(lambda: f.read(blksize), b""):
1056
+ hash_sha256.update(chunk)
1057
+
1058
+ return hash_sha256.hexdigest()
1059
+
1060
+
1061
+ def precalculate_safetensors_hashes(tensors, metadata):
1062
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
1063
+ save time on indexing the model later."""
1064
+
1065
+ # Because writing user metadata to the file can change the result of
1066
+ # sd_models.model_hash(), only retain the training metadata for purposes of
1067
+ # calculating the hash, as they are meant to be immutable
1068
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
1069
+
1070
+ bytes = safetensors.torch.save(tensors, metadata)
1071
+ b = BytesIO(bytes)
1072
+
1073
+ model_hash = addnet_hash_safetensors(b)
1074
+ legacy_hash = addnet_hash_legacy(b)
1075
+ return model_hash, legacy_hash
1076
+
1077
+
1078
+ def addnet_hash_legacy(b):
1079
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
1080
+ m = hashlib.sha256()
1081
+
1082
+ b.seek(0x100000)
1083
+ m.update(b.read(0x10000))
1084
+ return m.hexdigest()[0:8]
1085
+
1086
+
1087
+ def addnet_hash_safetensors(b):
1088
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
1089
+ hash_sha256 = hashlib.sha256()
1090
+ blksize = 1024 * 1024
1091
+
1092
+ b.seek(0)
1093
+ header = b.read(8)
1094
+ n = int.from_bytes(header, "little")
1095
+
1096
+ offset = n + 8
1097
+ b.seek(offset)
1098
+ for chunk in iter(lambda: b.read(blksize), b""):
1099
+ hash_sha256.update(chunk)
1100
+
1101
+ return hash_sha256.hexdigest()
1102
+
1103
+
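A small sketch of computing both hashes for an on-disk .safetensors file by reusing the two helpers above; the file name is hypothetical.

from io import BytesIO

def file_hashes(path):
    # Read the file into memory so both routines can seek freely over the same buffer.
    with open(path, "rb") as f:
        b = BytesIO(f.read())
    return addnet_hash_safetensors(b), addnet_hash_legacy(b)

# new_hash, legacy_hash = file_hashes("my_lora.safetensors")  # hypothetical file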
1104
+ def get_git_revision_hash() -> str:
1105
+ try:
1106
+ return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
1107
+ except:
1108
+ return "(unknown)"
1109
+
1110
+
1111
+ # flash attention forwards and backwards
1112
+
1113
+ # https://arxiv.org/abs/2205.14135
1114
+
1115
+
1116
+ class FlashAttentionFunction(torch.autograd.function.Function):
1117
+ @ staticmethod
1118
+ @ torch.no_grad()
1119
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
1120
+ """ Algorithm 2 in the paper """
1121
+
1122
+ device = q.device
1123
+ dtype = q.dtype
1124
+ max_neg_value = -torch.finfo(q.dtype).max
1125
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1126
+
1127
+ o = torch.zeros_like(q)
1128
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
1129
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
1130
+
1131
+ scale = (q.shape[-1] ** -0.5)
1132
+
1133
+ if not exists(mask):
1134
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
1135
+ else:
1136
+ mask = rearrange(mask, 'b n -> b 1 1 n')
1137
+ mask = mask.split(q_bucket_size, dim=-1)
1138
+
1139
+ row_splits = zip(
1140
+ q.split(q_bucket_size, dim=-2),
1141
+ o.split(q_bucket_size, dim=-2),
1142
+ mask,
1143
+ all_row_sums.split(q_bucket_size, dim=-2),
1144
+ all_row_maxes.split(q_bucket_size, dim=-2),
1145
+ )
1146
+
1147
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
1148
+ q_start_index = ind * q_bucket_size - qk_len_diff
1149
+
1150
+ col_splits = zip(
1151
+ k.split(k_bucket_size, dim=-2),
1152
+ v.split(k_bucket_size, dim=-2),
1153
+ )
1154
+
1155
+ for k_ind, (kc, vc) in enumerate(col_splits):
1156
+ k_start_index = k_ind * k_bucket_size
1157
+
1158
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1159
+
1160
+ if exists(row_mask):
1161
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
1162
+
1163
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1164
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1165
+ device=device).triu(q_start_index - k_start_index + 1)
1166
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1167
+
1168
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
1169
+ attn_weights -= block_row_maxes
1170
+ exp_weights = torch.exp(attn_weights)
1171
+
1172
+ if exists(row_mask):
1173
+ exp_weights.masked_fill_(~row_mask, 0.)
1174
+
1175
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
1176
+
1177
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
1178
+
1179
+ exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
1180
+
1181
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
1182
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
1183
+
1184
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
1185
+
1186
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
1187
+
1188
+ row_maxes.copy_(new_row_maxes)
1189
+ row_sums.copy_(new_row_sums)
1190
+
1191
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
1192
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
1193
+
1194
+ return o
1195
+
1196
+ @ staticmethod
1197
+ @ torch.no_grad()
1198
+ def backward(ctx, do):
1199
+ """ Algorithm 4 in the paper """
1200
+
1201
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
1202
+ q, k, v, o, l, m = ctx.saved_tensors
1203
+
1204
+ device = q.device
1205
+
1206
+ max_neg_value = -torch.finfo(q.dtype).max
1207
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1208
+
1209
+ dq = torch.zeros_like(q)
1210
+ dk = torch.zeros_like(k)
1211
+ dv = torch.zeros_like(v)
1212
+
1213
+ row_splits = zip(
1214
+ q.split(q_bucket_size, dim=-2),
1215
+ o.split(q_bucket_size, dim=-2),
1216
+ do.split(q_bucket_size, dim=-2),
1217
+ mask,
1218
+ l.split(q_bucket_size, dim=-2),
1219
+ m.split(q_bucket_size, dim=-2),
1220
+ dq.split(q_bucket_size, dim=-2)
1221
+ )
1222
+
1223
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
1224
+ q_start_index = ind * q_bucket_size - qk_len_diff
1225
+
1226
+ col_splits = zip(
1227
+ k.split(k_bucket_size, dim=-2),
1228
+ v.split(k_bucket_size, dim=-2),
1229
+ dk.split(k_bucket_size, dim=-2),
1230
+ dv.split(k_bucket_size, dim=-2),
1231
+ )
1232
+
1233
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
1234
+ k_start_index = k_ind * k_bucket_size
1235
+
1236
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1237
+
1238
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1239
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1240
+ device=device).triu(q_start_index - k_start_index + 1)
1241
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1242
+
1243
+ exp_attn_weights = torch.exp(attn_weights - mc)
1244
+
1245
+ if exists(row_mask):
1246
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
1247
+
1248
+ p = exp_attn_weights / lc
1249
+
1250
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
1251
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
1252
+
1253
+ D = (doc * oc).sum(dim=-1, keepdims=True)
1254
+ ds = p * scale * (dp - D)
1255
+
1256
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
1257
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
1258
+
1259
+ dqc.add_(dq_chunk)
1260
+ dkc.add_(dk_chunk)
1261
+ dvc.add_(dv_chunk)
1262
+
1263
+ return dq, dk, dv, None, None, None, None
1264
+
1265
+
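A minimal sketch of calling the autograd function above directly; the (batch, heads, tokens, head_dim) shapes are arbitrary, and the bucket sizes match the defaults used by forward_flash_attn below.

import torch

q = torch.randn(1, 8, 64, 40)
k = torch.randn(1, 8, 77, 40)
v = torch.randn(1, 8, 77, 40)
# mask=None, causal=False, q_bucket_size=512, k_bucket_size=1024
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
print(out.shape)  # torch.Size([1, 8, 64, 40]) -- same shape as q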
1266
+ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
1267
+ if mem_eff_attn:
1268
+ replace_unet_cross_attn_to_memory_efficient()
1269
+ elif xformers:
1270
+ replace_unet_cross_attn_to_xformers()
1271
+
1272
+
1273
+ def replace_unet_cross_attn_to_memory_efficient():
1274
+ print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
1275
+ flash_func = FlashAttentionFunction
1276
+
1277
+ def forward_flash_attn(self, x, context=None, mask=None):
1278
+ q_bucket_size = 512
1279
+ k_bucket_size = 1024
1280
+
1281
+ h = self.heads
1282
+ q = self.to_q(x)
1283
+
1284
+ context = context if context is not None else x
1285
+ context = context.to(x.dtype)
1286
+
1287
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1288
+ context_k, context_v = self.hypernetwork.forward(x, context)
1289
+ context_k = context_k.to(x.dtype)
1290
+ context_v = context_v.to(x.dtype)
1291
+ else:
1292
+ context_k = context
1293
+ context_v = context
1294
+
1295
+ k = self.to_k(context_k)
1296
+ v = self.to_v(context_v)
1297
+ del context, x
1298
+
1299
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
1300
+
1301
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
1302
+
1303
+ out = rearrange(out, 'b h n d -> b n (h d)')
1304
+
1305
+ # diffusers 0.7.0+: why did they have to change this... (;´Д`)
1306
+ out = self.to_out[0](out)
1307
+ out = self.to_out[1](out)
1308
+ return out
1309
+
1310
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
1311
+
1312
+
1313
+ def replace_unet_cross_attn_to_xformers():
1314
+ print("Replace CrossAttention.forward to use xformers")
1315
+ try:
1316
+ import xformers.ops
1317
+ except ImportError:
1318
+ raise ImportError("No xformers / xformersがインストールされていないようです")
1319
+
1320
+ def forward_xformers(self, x, context=None, mask=None):
1321
+ h = self.heads
1322
+ q_in = self.to_q(x)
1323
+
1324
+ context = default(context, x)
1325
+ context = context.to(x.dtype)
1326
+
1327
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1328
+ context_k, context_v = self.hypernetwork.forward(x, context)
1329
+ context_k = context_k.to(x.dtype)
1330
+ context_v = context_v.to(x.dtype)
1331
+ else:
1332
+ context_k = context
1333
+ context_v = context
1334
+
1335
+ k_in = self.to_k(context_k)
1336
+ v_in = self.to_v(context_v)
1337
+
1338
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
1339
+ del q_in, k_in, v_in
1340
+
1341
+ q = q.contiguous()
1342
+ k = k.contiguous()
1343
+ v = v.contiguous()
1344
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # picks the best available kernel automatically
1345
+
1346
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
1347
+
1348
+ # diffusers 0.7.0~
1349
+ out = self.to_out[0](out)
1350
+ out = self.to_out[1](out)
1351
+ return out
1352
+
1353
+ diffusers.models.attention.CrossAttention.forward = forward_xformers
1354
+ # endregion
1355
+
1356
+
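A minimal sketch of enabling one of the replacements above. Note that the patch swaps CrossAttention.forward on the diffusers class itself, so it affects every instance created afterwards as well; the unet argument itself is not modified here.

import diffusers

# A randomly initialized UNet is enough for illustration.
unet = diffusers.UNet2DConditionModel()
replace_unet_modules(unet, mem_eff_attn=True, xformers=False)  # uses the FlashAttention path above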
1357
+ # region arguments
1358
+
1359
+ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1360
+ # for pretrained models
1361
+ parser.add_argument("--v2", action='store_true',
1362
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
1363
+ parser.add_argument("--v_parameterization", action='store_true',
1364
+ help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
+ help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
+
1368
+
1369
+ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
1370
+ parser.add_argument("--output_dir", type=str, default=None,
1371
+ help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
1372
+ parser.add_argument("--output_name", type=str, default=None,
1373
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
1374
+ parser.add_argument("--save_precision", type=str, default=None,
1375
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
1376
+ parser.add_argument("--save_every_n_epochs", type=int, default=None,
1377
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
1378
+ parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
1379
+ help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
1380
+ parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
1381
+ parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
1382
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
1383
+ parser.add_argument("--save_state", action="store_true",
1384
+ help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
1385
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1386
+
1387
+ parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
+ parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
+ help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
+ parser.add_argument("--use_8bit_adam", action="store_true",
1391
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1393
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
+ parser.add_argument("--mem_eff_attn", action="store_true",
1395
+ help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
+ parser.add_argument("--xformers", action="store_true",
1397
+ help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
1398
+ parser.add_argument("--vae", type=str, default=None,
1399
+ help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
+
1401
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
+ parser.add_argument("--max_train_epochs", type=int, default=None,
1404
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
1405
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
1406
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
1407
+ parser.add_argument("--persistent_data_loader_workers", action="store_true",
1408
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
1409
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1410
+ parser.add_argument("--gradient_checkpointing", action="store_true",
1411
+ help="enable gradient checkpointing / grandient checkpointingを有効にする")
1412
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
1413
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
1414
+ parser.add_argument("--mixed_precision", type=str, default="no",
1415
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
1416
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1417
+ parser.add_argument("--clip_skip", type=int, default=None,
1418
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
1419
+ parser.add_argument("--logging_dir", type=str, default=None,
1420
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
+ parser.add_argument("--noise_offset", type=float, default=None,
1427
+ help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
+ parser.add_argument("--lowram", action="store_true",
1429
+ help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
+
1431
+ if support_dreambooth:
1432
+ # DreamBooth training
1433
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0,
1434
+ help="loss weight for regularization images / 正則化画像のlossの重み")
1435
+
1436
+
1437
+ def verify_training_args(args: argparse.Namespace):
1438
+ if args.v_parameterization and not args.v2:
1439
+ print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
1440
+ if args.v2 and args.clip_skip is not None:
1441
+ print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1442
+
1443
+
1444
+ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
1445
+ # dataset common
1446
+ parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
1447
+ parser.add_argument("--shuffle_caption", action="store_true",
1448
+ help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
1449
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
+ parser.add_argument("--caption_extention", type=str, default=None,
1451
+ help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
+ parser.add_argument("--keep_tokens", type=int, default=None,
1453
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
+ parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
+ parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
+ parser.add_argument("--face_crop_aug_range", type=str, default=None,
1457
+ help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
1458
+ parser.add_argument("--random_crop", action="store_true",
1459
+ help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
1460
+ parser.add_argument("--debug_dataset", action="store_true",
1461
+ help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
1462
+ parser.add_argument("--resolution", type=str, default=None,
1463
+ help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
1464
+ parser.add_argument("--cache_latents", action="store_true",
1465
+ help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
1466
+ parser.add_argument("--enable_bucket", action="store_true",
1467
+ help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
1468
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
1469
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
1470
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
1471
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
1472
+ parser.add_argument("--bucket_no_upscale", action="store_true",
1473
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
1474
+
1475
+ if support_caption_dropout:
1476
+ # Textual Inversion does not support caption dropout
1477
+ # prefixed with "caption" to avoid confusion with the usual tensor Dropout; every_n_epochs defaults to None to match the other options
1478
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
+ help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
+ help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
+ help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
+
1485
+ if support_dreambooth:
1486
+ # DreamBooth dataset
1487
+ parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
1488
+
1489
+ if support_caption:
1490
+ # caption dataset
1491
+ parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
1492
+ parser.add_argument("--dataset_repeats", type=int, default=1,
1493
+ help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
1494
+
1495
+
1496
+ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1497
+ parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
1498
+ help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
1499
+ parser.add_argument("--use_safetensors", action='store_true',
1500
+ help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
1501
+
1502
+ # endregion
1503
+
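A sketch of how a training script might compose these helpers into a single parser; the file names passed to parse_args are placeholders. prepare_dataset_args and load_tokenizer in the utils region below would then normalize and consume the parsed values.

import argparse

parser = argparse.ArgumentParser()
add_sd_models_arguments(parser)
add_dataset_arguments(parser, support_dreambooth=False, support_caption=True, support_caption_dropout=False)
add_training_arguments(parser, support_dreambooth=False)
add_sd_saving_arguments(parser)

args = parser.parse_args(["--pretrained_model_name_or_path", "model.ckpt", "--in_json", "meta.json"])
verify_training_args(args)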
1504
+ # region utils
1505
+
1506
+
1507
+ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
+ # backward compatibility
1509
+ if args.caption_extention is not None:
1510
+ args.caption_extension = args.caption_extention
1511
+ args.caption_extention = None
1512
+
1513
+ if args.cache_latents:
1514
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
+
1517
+ # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
+ if args.resolution is not None:
1519
+ args.resolution = tuple([int(r) for r in args.resolution.split(',')])
1520
+ if len(args.resolution) == 1:
1521
+ args.resolution = (args.resolution[0], args.resolution[0])
1522
+ assert len(args.resolution) == 2, \
1523
+ f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
1524
+
1525
+ if args.face_crop_aug_range is not None:
1526
+ args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
1527
+ assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
1528
+ f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
1529
+ else:
1530
+ args.face_crop_aug_range = None
1531
+
1532
+ if support_metadata:
1533
+ if args.in_json is not None and (args.color_aug or args.random_crop):
1534
+ print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
1535
+
1536
+
1537
+ def load_tokenizer(args: argparse.Namespace):
1538
+ print("prepare tokenizer")
1539
+ if args.v2:
1540
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
+ else:
1542
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
+ if args.max_token_length is not None:
1544
+ print(f"update token length: {args.max_token_length}")
1545
+ return tokenizer
1546
+
1547
+
1548
+ def prepare_accelerator(args: argparse.Namespace):
1549
+ if args.logging_dir is None:
1550
+ log_with = None
1551
+ logging_dir = None
1552
+ else:
1553
+ log_with = "tensorboard"
1554
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
1555
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
1556
+
1557
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
1558
+ log_with=log_with, logging_dir=logging_dir)
1559
+
1560
+ # work around accelerate version compatibility
1561
+ accelerator_0_15 = True
1562
+ try:
1563
+ accelerator.unwrap_model("dummy", True)
1564
+ print("Using accelerator 0.15.0 or above.")
1565
+ except TypeError:
1566
+ accelerator_0_15 = False
1567
+
1568
+ def unwrap_model(model):
1569
+ if accelerator_0_15:
1570
+ return accelerator.unwrap_model(model, True)
1571
+ return accelerator.unwrap_model(model)
1572
+
1573
+ return accelerator, unwrap_model
1574
+
1575
+
1576
+ def prepare_dtype(args: argparse.Namespace):
1577
+ weight_dtype = torch.float32
1578
+ if args.mixed_precision == "fp16":
1579
+ weight_dtype = torch.float16
1580
+ elif args.mixed_precision == "bf16":
1581
+ weight_dtype = torch.bfloat16
1582
+
1583
+ save_dtype = None
1584
+ if args.save_precision == "fp16":
1585
+ save_dtype = torch.float16
1586
+ elif args.save_precision == "bf16":
1587
+ save_dtype = torch.bfloat16
1588
+ elif args.save_precision == "float":
1589
+ save_dtype = torch.float32
1590
+
1591
+ return weight_dtype, save_dtype
1592
+
1593
+
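A quick, self-contained illustration of prepare_dtype; the Namespace fields mirror the --mixed_precision and --save_precision options defined above.

import argparse

args = argparse.Namespace(mixed_precision="bf16", save_precision="fp16")
weight_dtype, save_dtype = prepare_dtype(args)
print(weight_dtype, save_dtype)  # torch.bfloat16 torch.float16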
1594
+ def load_target_model(args: argparse.Namespace, weight_dtype):
1595
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
1596
+ if load_stable_diffusion_format:
1597
+ print("load StableDiffusion checkpoint")
1598
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
+ else:
1600
+ print("load Diffusers pretrained models")
1601
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1602
+ text_encoder = pipe.text_encoder
1603
+ vae = pipe.vae
1604
+ unet = pipe.unet
1605
+ del pipe
1606
+
1607
+ # load the VAE
1608
+ if args.vae is not None:
1609
+ vae = model_util.load_vae(args.vae, weight_dtype)
1610
+ print("additional VAE loaded")
1611
+
1612
+ return text_encoder, vae, unet, load_stable_diffusion_format
1613
+
1614
+
1615
+ def patch_accelerator_for_fp16_training(accelerator):
1616
+ org_unscale_grads = accelerator.scaler._unscale_grads_
1617
+
1618
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1619
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
1620
+
1621
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
1622
+
1623
+
1624
+ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
1625
+ # with no_token_padding, the length is not max length, return result immediately
1626
+ if input_ids.size()[-1] != tokenizer.model_max_length:
1627
+ return text_encoder(input_ids)[0]
1628
+
1629
+ b_size = input_ids.size()[0]
1630
+ input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
1631
+
1632
+ if args.clip_skip is None:
1633
+ encoder_hidden_states = text_encoder(input_ids)[0]
1634
+ else:
1635
+ enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
1636
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
1637
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
1638
+
1639
+ # bs*3, 77, 768 or 1024
1640
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
1641
+
1642
+ if args.max_token_length is not None:
1643
+ if args.v2:
1644
+ # v2: fold the three <BOS>...<EOS> <PAD> ... chunks back into a single <BOS>...<EOS> <PAD> ... sequence (honestly not sure this implementation is right)
1645
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1646
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1647
+ chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # from after <BOS> up to just before the last token
1648
+ if i > 0:
1649
+ for j in range(len(chunk)):
1650
+ if input_ids[j, 1] == tokenizer.eos_token_id: # empty caption, i.e. the <BOS> <EOS> <PAD> ... pattern
1651
+ chunk[j, 0] = chunk[j, 1] # copy the value of the next <PAD>
1652
+ states_list.append(chunk) # from after <BOS> to before <EOS>
1653
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
1654
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1655
+ else:
1656
+ # v1: fold the three <BOS>...<EOS> chunks back into a single <BOS>...<EOS> sequence
1657
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1658
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1659
+ states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # from after <BOS> to before <EOS>
1660
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
1661
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1662
+
1663
+ if weight_dtype is not None:
1664
+ # this is required for additional network training
1665
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
1666
+
1667
+ return encoder_hidden_states
1668
+
1669
+
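A worked example of the chunking above for --max_token_length=225 with CLIP's model_max_length of 77: the ids are reshaped into three 77-token windows, and after dropping the duplicated <BOS>/<EOS> per window the concatenated hidden states are 227 tokens long.

max_token_length, model_max_length = 225, 77
n_chunks = len(range(1, max_token_length, model_max_length))  # 3 windows of (77 - 2) tokens each
out_len = 1 + n_chunks * (model_max_length - 2) + 1           # <BOS> + 3 * 75 + final <EOS>/<PAD>
print(n_chunks, out_len)  # 3 227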
1670
+ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
1671
+ model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
1672
+ ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
1673
+ return model_name, ckpt_name
1674
+
1675
+
1676
+ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
1677
+ saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
1678
+ if saving:
1679
+ os.makedirs(args.output_dir, exist_ok=True)
1680
+ save_func()
1681
+
1682
+ if args.save_last_n_epochs is not None:
1683
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
1684
+ remove_old_func(remove_epoch_no)
1685
+ return saving
1686
+
1687
+
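A worked example of the retention arithmetic above, assuming --save_every_n_epochs=2 and --save_last_n_epochs=3:

save_every_n_epochs, save_last_n_epochs, epoch_no = 2, 3, 10
remove_epoch_no = epoch_no - save_every_n_epochs * save_last_n_epochs
print(remove_epoch_no)  # 4 -> the epoch-4 checkpoint is removed, keeping epochs 6, 8 and 10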
1688
+ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
1689
+ epoch_no = epoch + 1
1690
+ model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
1691
+
1692
+ if save_stable_diffusion_format:
1693
+ def save_sd():
1694
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1695
+ print(f"saving checkpoint: {ckpt_file}")
1696
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1697
+ src_path, epoch_no, global_step, save_dtype, vae)
1698
+
1699
+ def remove_sd(old_epoch_no):
1700
+ _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
1701
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1702
+ if os.path.exists(old_ckpt_file):
1703
+ print(f"removing old checkpoint: {old_ckpt_file}")
1704
+ os.remove(old_ckpt_file)
1705
+
1706
+ save_func = save_sd
1707
+ remove_old_func = remove_sd
1708
+ else:
1709
+ def save_du():
1710
+ out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
1711
+ print(f"saving model: {out_dir}")
1712
+ os.makedirs(out_dir, exist_ok=True)
1713
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1714
+ src_path, vae=vae, use_safetensors=use_safetensors)
1715
+
1716
+ def remove_du(old_epoch_no):
1717
+ out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
1718
+ if os.path.exists(out_dir_old):
1719
+ print(f"removing old model: {out_dir_old}")
1720
+ shutil.rmtree(out_dir_old)
1721
+
1722
+ save_func = save_du
1723
+ remove_old_func = remove_du
1724
+
1725
+ saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
1726
+ if saving and args.save_state:
1727
+ save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
1728
+
1729
+
1730
+ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
1731
+ print("saving state.")
1732
+ accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
1733
+
1734
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
1735
+ if last_n_epochs is not None:
1736
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
1737
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
1738
+ if os.path.exists(state_dir_old):
1739
+ print(f"removing old state: {state_dir_old}")
1740
+ shutil.rmtree(state_dir_old)
1741
+
1742
+
1743
+ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
1744
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1745
+
1746
+ if save_stable_diffusion_format:
1747
+ os.makedirs(args.output_dir, exist_ok=True)
1748
+
1749
+ ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
1750
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1751
+
1752
+ print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
1753
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1754
+ src_path, epoch, global_step, save_dtype, vae)
1755
+ else:
1756
+ out_dir = os.path.join(args.output_dir, model_name)
1757
+ os.makedirs(out_dir, exist_ok=True)
1758
+
1759
+ print(f"save trained model as Diffusers to {out_dir}")
1760
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1761
+ src_path, vae=vae, use_safetensors=use_safetensors)
1762
+
1763
+
1764
+ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1765
+ print("saving last state.")
1766
+ os.makedirs(args.output_dir, exist_ok=True)
1767
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
+ accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
+
1770
+ # endregion
1771
+
1772
+ # region preprocessing helpers
1773
+
1774
+
1775
+ class ImageLoadingDataset(torch.utils.data.Dataset):
1776
+ def __init__(self, image_paths):
1777
+ self.images = image_paths
1778
+
1779
+ def __len__(self):
1780
+ return len(self.images)
1781
+
1782
+ def __getitem__(self, idx):
1783
+ img_path = self.images[idx]
1784
+
1785
+ try:
1786
+ image = Image.open(img_path).convert("RGB")
1787
+ # convert to tensor temporarily so dataloader will accept it
1788
+ tensor_pil = transforms.functional.pil_to_tensor(image)
1789
+ except Exception as e:
1790
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
1791
+ return None
1792
+
1793
+ return (tensor_pil, img_path)
1794
+
1795
+
1796
+ # endregion
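A minimal sketch of driving this dataset for preprocessing; the directory name is hypothetical, and unreadable files (returned as None above) are dropped by the collate function.

import torch

paths = glob_images("train_images")  # hypothetical directory
loader = torch.utils.data.DataLoader(
    ImageLoadingDataset(paths),
    batch_size=1,
    collate_fn=lambda batch: [item for item in batch if item is not None],
)
for batch in loader:
    for tensor_pil, img_path in batch:
        print(img_path, tuple(tensor_pil.shape))  # uint8 tensor shaped (3, H, W)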
locon/__init__.py ADDED
File without changes
locon/kohya_model_utils.py ADDED
@@ -0,0 +1,1184 @@
1
+ '''
2
+ https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
3
+ '''
4
+ # v1: split from train_db_fixed.py.
5
+ # v2: support safetensors
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
11
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
12
+ from safetensors.torch import load_file, save_file
13
+
14
+ # DiffUsers版StableDiffusionのモデルパラメータ
15
+ NUM_TRAIN_TIMESTEPS = 1000
16
+ BETA_START = 0.00085
17
+ BETA_END = 0.0120
18
+
19
+ UNET_PARAMS_MODEL_CHANNELS = 320
20
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
21
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
22
+ UNET_PARAMS_IMAGE_SIZE = 32 # unused
23
+ UNET_PARAMS_IN_CHANNELS = 4
24
+ UNET_PARAMS_OUT_CHANNELS = 4
25
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
26
+ UNET_PARAMS_CONTEXT_DIM = 768
27
+ UNET_PARAMS_NUM_HEADS = 8
28
+
29
+ VAE_PARAMS_Z_CHANNELS = 4
30
+ VAE_PARAMS_RESOLUTION = 256
31
+ VAE_PARAMS_IN_CHANNELS = 3
32
+ VAE_PARAMS_OUT_CH = 3
33
+ VAE_PARAMS_CH = 128
34
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
35
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
36
+
37
+ # V2
38
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
39
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
40
+
41
+ # Diffusersの設定を読み込むための参照モデル
42
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
43
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
44
+
45
+
46
+ # region StableDiffusion->Diffusersの変換コード
47
+ # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
48
+
49
+
50
+ def shave_segments(path, n_shave_prefix_segments=1):
51
+ """
52
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
53
+ """
54
+ if n_shave_prefix_segments >= 0:
55
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
56
+ else:
57
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
58
+
59
+
60
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
+ """
62
+ Updates paths inside resnets to the new naming scheme (local renaming)
63
+ """
64
+ mapping = []
65
+ for old_item in old_list:
66
+ new_item = old_item.replace("in_layers.0", "norm1")
67
+ new_item = new_item.replace("in_layers.2", "conv1")
68
+
69
+ new_item = new_item.replace("out_layers.0", "norm2")
70
+ new_item = new_item.replace("out_layers.3", "conv2")
71
+
72
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
74
+
75
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
+
77
+ mapping.append({"old": old_item, "new": new_item})
78
+
79
+ return mapping
80
+
81
+
82
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
+ """
84
+ Updates paths inside resnets to the new naming scheme (local renaming)
85
+ """
86
+ mapping = []
87
+ for old_item in old_list:
88
+ new_item = old_item
89
+
90
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
+
93
+ mapping.append({"old": old_item, "new": new_item})
94
+
95
+ return mapping
96
+
97
+
98
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
+ """
100
+ Updates paths inside attentions to the new naming scheme (local renaming)
101
+ """
102
+ mapping = []
103
+ for old_item in old_list:
104
+ new_item = old_item
105
+
106
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
+
109
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
+
112
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
+ """
121
+ Updates paths inside attentions to the new naming scheme (local renaming)
122
+ """
123
+ mapping = []
124
+ for old_item in old_list:
125
+ new_item = old_item
126
+
127
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
128
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
129
+
130
+ new_item = new_item.replace("q.weight", "query.weight")
131
+ new_item = new_item.replace("q.bias", "query.bias")
132
+
133
+ new_item = new_item.replace("k.weight", "key.weight")
134
+ new_item = new_item.replace("k.bias", "key.bias")
135
+
136
+ new_item = new_item.replace("v.weight", "value.weight")
137
+ new_item = new_item.replace("v.bias", "value.bias")
138
+
139
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
+
142
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
+
144
+ mapping.append({"old": old_item, "new": new_item})
145
+
146
+ return mapping
147
+
148
+
149
+ def assign_to_checkpoint(
150
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
+ ):
152
+ """
153
+ This does the final conversion step: take locally converted weights and apply a global renaming
154
+ to them. It splits attention layers, and takes into account additional replacements
155
+ that may arise.
156
+
157
+ Assigns the weights to the new checkpoint.
158
+ """
159
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
+
161
+ # Splits the attention layers into three variables.
162
+ if attention_paths_to_split is not None:
163
+ for path, path_map in attention_paths_to_split.items():
164
+ old_tensor = old_checkpoint[path]
165
+ channels = old_tensor.shape[0] // 3
166
+
167
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
+
169
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
+
171
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
+
174
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
175
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
176
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
177
+
178
+ for path in paths:
179
+ new_path = path["new"]
180
+
181
+ # These have already been assigned
182
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
+ continue
184
+
185
+ # Global renaming happens here
186
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
+
190
+ if additional_replacements is not None:
191
+ for replacement in additional_replacements:
192
+ new_path = new_path.replace(replacement["old"], replacement["new"])
193
+
194
+ # proj_attn.weight has to be converted from conv 1D to linear
195
+ if "proj_attn.weight" in new_path:
196
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
+ else:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]]
199
+
200
+
201
+ def conv_attn_to_linear(checkpoint):
202
+ keys = list(checkpoint.keys())
203
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
204
+ for key in keys:
205
+ if ".".join(key.split(".")[-2:]) in attn_keys:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
+ elif "proj_attn.weight" in key:
209
+ if checkpoint[key].ndim > 2:
210
+ checkpoint[key] = checkpoint[key][:, :, 0]
211
+
212
+
213
+ def linear_transformer_to_conv(checkpoint):
214
+ keys = list(checkpoint.keys())
215
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
216
+ for key in keys:
217
+ if ".".join(key.split(".")[-2:]) in tf_keys:
218
+ if checkpoint[key].ndim == 2:
219
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
220
+
221
+
222
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
223
+ """
224
+ Takes a state dict and a config, and returns a converted checkpoint.
225
+ """
226
+
227
+ # extract state_dict for UNet
228
+ unet_state_dict = {}
229
+ unet_key = "model.diffusion_model."
230
+ keys = list(checkpoint.keys())
231
+ for key in keys:
232
+ if key.startswith(unet_key):
233
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
234
+
235
+ new_checkpoint = {}
236
+
237
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
238
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
239
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
240
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
241
+
242
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
243
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
244
+
245
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
246
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
247
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
248
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
249
+
250
+ # Retrieves the keys for the input blocks only
251
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
252
+ input_blocks = {
253
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
254
+ for layer_id in range(num_input_blocks)
255
+ }
256
+
257
+ # Retrieves the keys for the middle blocks only
258
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
259
+ middle_blocks = {
260
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
261
+ for layer_id in range(num_middle_blocks)
262
+ }
263
+
264
+ # Retrieves the keys for the output blocks only
265
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
266
+ output_blocks = {
267
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
268
+ for layer_id in range(num_output_blocks)
269
+ }
270
+
271
+ for i in range(1, num_input_blocks):
272
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
273
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
274
+
275
+ resnets = [
276
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
277
+ ]
278
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
279
+
280
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.weight"
283
+ )
284
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
285
+ f"input_blocks.{i}.0.op.bias"
286
+ )
287
+
288
+ paths = renew_resnet_paths(resnets)
289
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
290
+ assign_to_checkpoint(
291
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
292
+ )
293
+
294
+ if len(attentions):
295
+ paths = renew_attention_paths(attentions)
296
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
297
+ assign_to_checkpoint(
298
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
299
+ )
300
+
301
+ resnet_0 = middle_blocks[0]
302
+ attentions = middle_blocks[1]
303
+ resnet_1 = middle_blocks[2]
304
+
305
+ resnet_0_paths = renew_resnet_paths(resnet_0)
306
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ resnet_1_paths = renew_resnet_paths(resnet_1)
309
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
310
+
311
+ attentions_paths = renew_attention_paths(attentions)
312
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
313
+ assign_to_checkpoint(
314
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
315
+ )
316
+
317
+ for i in range(num_output_blocks):
318
+ block_id = i // (config["layers_per_block"] + 1)
319
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
320
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
321
+ output_block_list = {}
322
+
323
+ for layer in output_block_layers:
324
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
325
+ if layer_id in output_block_list:
326
+ output_block_list[layer_id].append(layer_name)
327
+ else:
328
+ output_block_list[layer_id] = [layer_name]
329
+
330
+ if len(output_block_list) > 1:
331
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
332
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
333
+
334
+ resnet_0_paths = renew_resnet_paths(resnets)
335
+ paths = renew_resnet_paths(resnets)
336
+
337
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
338
+ assign_to_checkpoint(
339
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
340
+ )
341
+
342
+ # Original:
343
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
344
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
345
+
346
+ # avoid depending on the order of bias and weight; there is probably a better way to do this
347
+ for l in output_block_list.values():
348
+ l.sort()
349
+
350
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
351
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.bias"
354
+ ]
355
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
356
+ f"output_blocks.{i}.{index}.conv.weight"
357
+ ]
358
+
359
+ # Clear attentions as they have been attributed above.
360
+ if len(attentions) == 2:
361
+ attentions = []
362
+
363
+ if len(attentions):
364
+ paths = renew_attention_paths(attentions)
365
+ meta_path = {
366
+ "old": f"output_blocks.{i}.1",
367
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
368
+ }
369
+ assign_to_checkpoint(
370
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
371
+ )
372
+ else:
373
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
374
+ for path in resnet_0_paths:
375
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
376
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
377
+
378
+ new_checkpoint[new_path] = unet_state_dict[old_path]
379
+
380
+ # in SD v2 the 1x1 conv2d layers were replaced with linear layers, so convert linear -> conv
381
+ if v2:
382
+ linear_transformer_to_conv(new_checkpoint)
383
+
384
+ return new_checkpoint
385
+
386
+
387
+ def convert_ldm_vae_checkpoint(checkpoint, config):
388
+ # extract state dict for VAE
389
+ vae_state_dict = {}
390
+ vae_key = "first_stage_model."
391
+ keys = list(checkpoint.keys())
392
+ for key in keys:
393
+ if key.startswith(vae_key):
394
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
395
+ # if len(vae_state_dict) == 0:
396
+ # # the given checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt file
397
+ # vae_state_dict = checkpoint
398
+
399
+ new_checkpoint = {}
400
+
401
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
402
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
403
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
404
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
405
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
406
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
407
+
408
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
409
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
410
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
411
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
412
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
413
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
414
+
415
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
416
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
417
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
418
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
419
+
420
+ # Retrieves the keys for the encoder down blocks only
421
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
422
+ down_blocks = {
423
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
424
+ }
425
+
426
+ # Retrieves the keys for the decoder up blocks only
427
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
428
+ up_blocks = {
429
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
430
+ }
431
+
432
+ for i in range(num_down_blocks):
433
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
434
+
435
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.weight"
438
+ )
439
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
440
+ f"encoder.down.{i}.downsample.conv.bias"
441
+ )
442
+
443
+ paths = renew_vae_resnet_paths(resnets)
444
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
445
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
446
+
447
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
448
+ num_mid_res_blocks = 2
449
+ for i in range(1, num_mid_res_blocks + 1):
450
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
451
+
452
+ paths = renew_vae_resnet_paths(resnets)
453
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
454
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
+
456
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
457
+ paths = renew_vae_attention_paths(mid_attentions)
458
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
459
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
460
+ conv_attn_to_linear(new_checkpoint)
461
+
462
+ for i in range(num_up_blocks):
463
+ block_id = num_up_blocks - 1 - i
464
+ resnets = [
465
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
466
+ ]
467
+
468
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.weight"
471
+ ]
472
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
473
+ f"decoder.up.{block_id}.upsample.conv.bias"
474
+ ]
475
+
476
+ paths = renew_vae_resnet_paths(resnets)
477
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
478
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
479
+
480
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
481
+ num_mid_res_blocks = 2
482
+ for i in range(1, num_mid_res_blocks + 1):
483
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
484
+
485
+ paths = renew_vae_resnet_paths(resnets)
486
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
487
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
488
+
489
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
490
+ paths = renew_vae_attention_paths(mid_attentions)
491
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
492
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
493
+ conv_attn_to_linear(new_checkpoint)
494
+ return new_checkpoint
495
+
496
+
497
+ def create_unet_diffusers_config(v2):
498
+ """
499
+ Creates a config for the diffusers based on the config of the LDM model.
500
+ """
501
+ # unet_params = original_config.model.params.unet_config.params
502
+
503
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
504
+
505
+ down_block_types = []
506
+ resolution = 1
507
+ for i in range(len(block_out_channels)):
508
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
509
+ down_block_types.append(block_type)
510
+ if i != len(block_out_channels) - 1:
511
+ resolution *= 2
512
+
513
+ up_block_types = []
514
+ for i in range(len(block_out_channels)):
515
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
516
+ up_block_types.append(block_type)
517
+ resolution //= 2
518
+
519
+ config = dict(
520
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
521
+ in_channels=UNET_PARAMS_IN_CHANNELS,
522
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
523
+ down_block_types=tuple(down_block_types),
524
+ up_block_types=tuple(up_block_types),
525
+ block_out_channels=tuple(block_out_channels),
526
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
527
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
528
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
529
+ )
530
+
531
+ return config
532
+
533
+
534
+ def create_vae_diffusers_config():
535
+ """
536
+ Creates a config for the diffusers based on the config of the LDM model.
537
+ """
538
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
539
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
540
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
541
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
542
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
543
+
544
+ config = dict(
545
+ sample_size=VAE_PARAMS_RESOLUTION,
546
+ in_channels=VAE_PARAMS_IN_CHANNELS,
547
+ out_channels=VAE_PARAMS_OUT_CH,
548
+ down_block_types=tuple(down_block_types),
549
+ up_block_types=tuple(up_block_types),
550
+ block_out_channels=tuple(block_out_channels),
551
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
552
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
553
+ )
554
+ return config
555
+
556
+
557
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
558
+ keys = list(checkpoint.keys())
559
+ text_model_dict = {}
560
+ for key in keys:
561
+ if key.startswith("cond_stage_model.transformer"):
562
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
563
+ return text_model_dict
564
+
565
+
566
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
567
+ # the key layout is annoyingly different!
568
+ def convert_key(key):
569
+ if not key.startswith("cond_stage_model"):
570
+ return None
571
+
572
+ # common conversion
573
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
574
+ key = key.replace("cond_stage_model.model.", "text_model.")
575
+
576
+ if "resblocks" in key:
577
+ # resblocks conversion
578
+ key = key.replace(".resblocks.", ".layers.")
579
+ if ".ln_" in key:
580
+ key = key.replace(".ln_", ".layer_norm")
581
+ elif ".mlp." in key:
582
+ key = key.replace(".c_fc.", ".fc1.")
583
+ key = key.replace(".c_proj.", ".fc2.")
584
+ elif '.attn.out_proj' in key:
585
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
586
+ elif '.attn.in_proj' in key:
587
+ key = None # special case, handled later
588
+ else:
589
+ raise ValueError(f"unexpected key in SD: {key}")
590
+ elif '.positional_embedding' in key:
591
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
592
+ elif '.text_projection' in key:
593
+ key = None # not used???
594
+ elif '.logit_scale' in key:
595
+ key = None # not used???
596
+ elif '.token_embedding' in key:
597
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
598
+ elif '.ln_final' in key:
599
+ key = key.replace(".ln_final", ".final_layer_norm")
600
+ return key
601
+
602
+ keys = list(checkpoint.keys())
603
+ new_sd = {}
604
+ for key in keys:
605
+ # remove resblocks 23
606
+ if '.resblocks.23.' in key:
607
+ continue
608
+ new_key = convert_key(key)
609
+ if new_key is None:
610
+ continue
611
+ new_sd[new_key] = checkpoint[key]
612
+
613
+ # convert attention weights
614
+ for key in keys:
615
+ if '.resblocks.23.' in key:
616
+ continue
617
+ if '.resblocks' in key and '.attn.in_proj_' in key:
618
+ # split into three (q, k, v)
619
+ values = torch.chunk(checkpoint[key], 3)
620
+
621
+ key_suffix = ".weight" if "weight" in key else ".bias"
622
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
623
+ key_pfx = key_pfx.replace("_weight", "")
624
+ key_pfx = key_pfx.replace("_bias", "")
625
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
626
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
627
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
628
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
629
+
630
+ # rename or add position_ids
631
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
632
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
633
+ # waifu diffusion v1.4
634
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
635
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
636
+ else:
637
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
638
+
639
+ new_sd["text_model.embeddings.position_ids"] = position_ids
640
+ return new_sd
641
+
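As a side note on the attention split handled above, here is a minimal, self-contained sketch of how a packed OpenCLIP in_proj matrix is chunked into q/k/v; the size 1024 is a placeholder hidden size, not taken from this file:

# Illustrative sketch only; 1024 is a placeholder hidden size.
import torch

in_proj_weight = torch.randn(3 * 1024, 1024)    # q, k and v stacked along dim 0
q_w, k_w, v_w = torch.chunk(in_proj_weight, 3)  # same split used in convert_ldm_clip_checkpoint_v2
print(q_w.shape, k_w.shape, v_w.shape)          # each torch.Size([1024, 1024])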
642
+ # endregion
643
+
644
+
645
+ # region Diffusers->StableDiffusion conversion code
646
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
647
+
648
+ def conv_transformer_to_linear(checkpoint):
649
+ keys = list(checkpoint.keys())
650
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
651
+ for key in keys:
652
+ if ".".join(key.split(".")[-2:]) in tf_keys:
653
+ if checkpoint[key].ndim > 2:
654
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
655
+
656
+
657
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
658
+ unet_conversion_map = [
659
+ # (stable-diffusion, HF Diffusers)
660
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
661
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
662
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
663
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
664
+ ("input_blocks.0.0.weight", "conv_in.weight"),
665
+ ("input_blocks.0.0.bias", "conv_in.bias"),
666
+ ("out.0.weight", "conv_norm_out.weight"),
667
+ ("out.0.bias", "conv_norm_out.bias"),
668
+ ("out.2.weight", "conv_out.weight"),
669
+ ("out.2.bias", "conv_out.bias"),
670
+ ]
671
+
672
+ unet_conversion_map_resnet = [
673
+ # (stable-diffusion, HF Diffusers)
674
+ ("in_layers.0", "norm1"),
675
+ ("in_layers.2", "conv1"),
676
+ ("out_layers.0", "norm2"),
677
+ ("out_layers.3", "conv2"),
678
+ ("emb_layers.1", "time_emb_proj"),
679
+ ("skip_connection", "conv_shortcut"),
680
+ ]
681
+
682
+ unet_conversion_map_layer = []
683
+ for i in range(4):
684
+ # loop over downblocks/upblocks
685
+
686
+ for j in range(2):
687
+ # loop over resnets/attentions for downblocks
688
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
689
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
690
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
691
+
692
+ if i < 3:
693
+ # no attention layers in down_blocks.3
694
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
695
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
696
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
697
+
698
+ for j in range(3):
699
+ # loop over resnets/attentions for upblocks
700
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
701
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
702
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
703
+
704
+ if i > 0:
705
+ # no attention layers in up_blocks.0
706
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
707
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
708
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
709
+
710
+ if i < 3:
711
+ # no downsample in down_blocks.3
712
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
713
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
714
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
715
+
716
+ # no upsample in up_blocks.3
717
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
718
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
719
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
720
+
721
+ hf_mid_atn_prefix = "mid_block.attentions.0."
722
+ sd_mid_atn_prefix = "middle_block.1."
723
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
724
+
725
+ for j in range(2):
726
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
727
+ sd_mid_res_prefix = f"middle_block.{2*j}."
728
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
729
+
730
+ # buyer beware: this is a *brittle* function,
731
+ # and correct output requires that all of these pieces interact in
732
+ # the exact order in which I have arranged them.
733
+ mapping = {k: k for k in unet_state_dict.keys()}
734
+ for sd_name, hf_name in unet_conversion_map:
735
+ mapping[hf_name] = sd_name
736
+ for k, v in mapping.items():
737
+ if "resnets" in k:
738
+ for sd_part, hf_part in unet_conversion_map_resnet:
739
+ v = v.replace(hf_part, sd_part)
740
+ mapping[k] = v
741
+ for k, v in mapping.items():
742
+ for sd_part, hf_part in unet_conversion_map_layer:
743
+ v = v.replace(hf_part, sd_part)
744
+ mapping[k] = v
745
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
746
+
747
+ if v2:
748
+ conv_transformer_to_linear(new_state_dict)
749
+
750
+ return new_state_dict
751
+
752
+
753
+ # ================#
754
+ # VAE Conversion #
755
+ # ================#
756
+
757
+ def reshape_weight_for_sd(w):
758
+ # convert HF linear weights to SD conv2d weights
759
+ return w.reshape(*w.shape, 1, 1)
760
+
761
+
762
+ def convert_vae_state_dict(vae_state_dict):
763
+ vae_conversion_map = [
764
+ # (stable-diffusion, HF Diffusers)
765
+ ("nin_shortcut", "conv_shortcut"),
766
+ ("norm_out", "conv_norm_out"),
767
+ ("mid.attn_1.", "mid_block.attentions.0."),
768
+ ]
769
+
770
+ for i in range(4):
771
+ # down_blocks have two resnets
772
+ for j in range(2):
773
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
774
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
775
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
776
+
777
+ if i < 3:
778
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
779
+ sd_downsample_prefix = f"down.{i}.downsample."
780
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
781
+
782
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
783
+ sd_upsample_prefix = f"up.{3-i}.upsample."
784
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
785
+
786
+ # up_blocks have three resnets
787
+ # also, up blocks in hf are numbered in reverse from sd
788
+ for j in range(3):
789
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
790
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
791
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
792
+
793
+ # this part accounts for mid blocks in both the encoder and the decoder
794
+ for i in range(2):
795
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
796
+ sd_mid_res_prefix = f"mid.block_{i+1}."
797
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
798
+
799
+ vae_conversion_map_attn = [
800
+ # (stable-diffusion, HF Diffusers)
801
+ ("norm.", "group_norm."),
802
+ ("q.", "query."),
803
+ ("k.", "key."),
804
+ ("v.", "value."),
805
+ ("proj_out.", "proj_attn."),
806
+ ]
807
+
808
+ mapping = {k: k for k in vae_state_dict.keys()}
809
+ for k, v in mapping.items():
810
+ for sd_part, hf_part in vae_conversion_map:
811
+ v = v.replace(hf_part, sd_part)
812
+ mapping[k] = v
813
+ for k, v in mapping.items():
814
+ if "attentions" in k:
815
+ for sd_part, hf_part in vae_conversion_map_attn:
816
+ v = v.replace(hf_part, sd_part)
817
+ mapping[k] = v
818
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
819
+ weights_to_convert = ["q", "k", "v", "proj_out"]
820
+ for k, v in new_state_dict.items():
821
+ for weight_name in weights_to_convert:
822
+ if f"mid.attn_1.{weight_name}.weight" in k:
823
+ # print(f"Reshaping {k} for SD format")
824
+ new_state_dict[k] = reshape_weight_for_sd(v)
825
+
826
+ return new_state_dict
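As a small illustration of reshape_weight_for_sd above (assuming this module is importable as library.model_util; 512 is a placeholder channel count), an HF linear attention weight becomes a 1x1 conv weight in the SD layout:

# Illustrative sketch only; 512 is a placeholder channel count.
import torch
from library import model_util

w_linear = torch.randn(512, 512)                     # HF stores VAE attention projections as Linear
w_conv = model_util.reshape_weight_for_sd(w_linear)  # SD stores them as 1x1 Conv2d weights
print(w_conv.shape)                                  # torch.Size([512, 512, 1, 1])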
827
+
828
+
829
+ # endregion
830
+
831
+ # region custom model loading / saving helpers
832
+
833
+ def is_safetensors(path):
834
+ return os.path.splitext(path)[1].lower() == '.safetensors'
835
+
836
+
837
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
838
+ # handle models that store the text encoder in a different layout (no 'text_model' prefix)
839
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
840
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
841
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
842
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
843
+ ]
844
+
845
+ if is_safetensors(ckpt_path):
846
+ checkpoint = None
847
+ state_dict = load_file(ckpt_path, "cpu")
848
+ else:
849
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
850
+ if "state_dict" in checkpoint:
851
+ state_dict = checkpoint["state_dict"]
852
+ else:
853
+ state_dict = checkpoint
854
+ checkpoint = None
855
+
856
+ key_reps = []
857
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
858
+ for key in state_dict.keys():
859
+ if key.startswith(rep_from):
860
+ new_key = rep_to + key[len(rep_from):]
861
+ key_reps.append((key, new_key))
862
+
863
+ for key, new_key in key_reps:
864
+ state_dict[new_key] = state_dict[key]
865
+ del state_dict[key]
866
+
867
+ return checkpoint, state_dict
868
+
869
+
870
+ # TODO the dtype handling is questionable and should be checked; not yet confirmed that the text_encoder can be built with the specified dtype
871
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
872
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
873
+ if dtype is not None:
874
+ for k, v in state_dict.items():
875
+ if type(v) is torch.Tensor:
876
+ state_dict[k] = v.to(dtype)
877
+
878
+ # Convert the UNet2DConditionModel model.
879
+ unet_config = create_unet_diffusers_config(v2)
880
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
881
+
882
+ unet = UNet2DConditionModel(**unet_config)
883
+ info = unet.load_state_dict(converted_unet_checkpoint)
884
+ print("loading u-net:", info)
885
+
886
+ # Convert the VAE model.
887
+ vae_config = create_vae_diffusers_config()
888
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
889
+
890
+ vae = AutoencoderKL(**vae_config)
891
+ info = vae.load_state_dict(converted_vae_checkpoint)
892
+ print("loading vae:", info)
893
+
894
+ # convert text_model
895
+ if v2:
896
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
897
+ cfg = CLIPTextConfig(
898
+ vocab_size=49408,
899
+ hidden_size=1024,
900
+ intermediate_size=4096,
901
+ num_hidden_layers=23,
902
+ num_attention_heads=16,
903
+ max_position_embeddings=77,
904
+ hidden_act="gelu",
905
+ layer_norm_eps=1e-05,
906
+ dropout=0.0,
907
+ attention_dropout=0.0,
908
+ initializer_range=0.02,
909
+ initializer_factor=1.0,
910
+ pad_token_id=1,
911
+ bos_token_id=0,
912
+ eos_token_id=2,
913
+ model_type="clip_text_model",
914
+ projection_dim=512,
915
+ torch_dtype="float32",
916
+ transformers_version="4.25.0.dev0",
917
+ )
918
+ text_model = CLIPTextModel._from_config(cfg)
919
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
920
+ else:
921
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
922
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
923
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
924
+ print("loading text encoder:", info)
925
+
926
+ return text_model, vae, unet
927
+
928
+
929
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
930
+ def convert_key(key):
931
+ # remove position_ids
932
+ if ".position_ids" in key:
933
+ return None
934
+
935
+ # common
936
+ key = key.replace("text_model.encoder.", "transformer.")
937
+ key = key.replace("text_model.", "")
938
+ if "layers" in key:
939
+ # resblocks conversion
940
+ key = key.replace(".layers.", ".resblocks.")
941
+ if ".layer_norm" in key:
942
+ key = key.replace(".layer_norm", ".ln_")
943
+ elif ".mlp." in key:
944
+ key = key.replace(".fc1.", ".c_fc.")
945
+ key = key.replace(".fc2.", ".c_proj.")
946
+ elif '.self_attn.out_proj' in key:
947
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
948
+ elif '.self_attn.' in key:
949
+ key = None # special case, handled later
950
+ else:
951
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
952
+ elif '.position_embedding' in key:
953
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
954
+ elif '.token_embedding' in key:
955
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
956
+ elif 'final_layer_norm' in key:
957
+ key = key.replace("final_layer_norm", "ln_final")
958
+ return key
959
+
960
+ keys = list(checkpoint.keys())
961
+ new_sd = {}
962
+ for key in keys:
963
+ new_key = convert_key(key)
964
+ if new_key is None:
965
+ continue
966
+ new_sd[new_key] = checkpoint[key]
967
+
968
+ # convert attention weights
969
+ for key in keys:
970
+ if 'layers' in key and 'q_proj' in key:
971
+ # concatenate the three (q, k, v)
972
+ key_q = key
973
+ key_k = key.replace("q_proj", "k_proj")
974
+ key_v = key.replace("q_proj", "v_proj")
975
+
976
+ value_q = checkpoint[key_q]
977
+ value_k = checkpoint[key_k]
978
+ value_v = checkpoint[key_v]
979
+ value = torch.cat([value_q, value_k, value_v])
980
+
981
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
982
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
983
+ new_sd[new_key] = value
984
+
985
+ # whether to fabricate the final layer and other missing weights
986
+ if make_dummy_weights:
987
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
988
+ keys = list(new_sd.keys())
989
+ for key in keys:
990
+ if key.startswith("transformer.resblocks.22."):
991
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without cloning, saving with safetensors crashes
992
+
993
+ # create the weights that are not present in the Diffusers model
994
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
995
+ new_sd['logit_scale'] = torch.tensor(1)
996
+
997
+ return new_sd
998
+
999
+
1000
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
1001
+ if ckpt_path is not None:
1002
+ # reuse epoch/step from it; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
1003
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1004
+ if checkpoint is None: # safetensors, or a ckpt that is only a state_dict
1005
+ checkpoint = {}
1006
+ strict = False
1007
+ else:
1008
+ strict = True
1009
+ if "state_dict" in state_dict:
1010
+ del state_dict["state_dict"]
1011
+ else:
1012
+ # build a new checkpoint from scratch
1013
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1014
+ checkpoint = {}
1015
+ state_dict = {}
1016
+ strict = False
1017
+
1018
+ def update_sd(prefix, sd):
1019
+ for k, v in sd.items():
1020
+ key = prefix + k
1021
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1022
+ if save_dtype is not None:
1023
+ v = v.detach().clone().to("cpu").to(save_dtype)
1024
+ state_dict[key] = v
1025
+
1026
+ # Convert the UNet model
1027
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1028
+ update_sd("model.diffusion_model.", unet_state_dict)
1029
+
1030
+ # Convert the text encoder model
1031
+ if v2:
1032
+ make_dummy = ckpt_path is None # when there is no source checkpoint, insert dummy weights, e.g. duplicate the final layer from the previous one
1033
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1034
+ update_sd("cond_stage_model.model.", text_enc_dict)
1035
+ else:
1036
+ text_enc_dict = text_encoder.state_dict()
1037
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1038
+
1039
+ # Convert the VAE
1040
+ if vae is not None:
1041
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1042
+ update_sd("first_stage_model.", vae_dict)
1043
+
1044
+ # Put together new checkpoint
1045
+ key_count = len(state_dict.keys())
1046
+ new_ckpt = {'state_dict': state_dict}
1047
+
1048
+ if 'epoch' in checkpoint:
1049
+ epochs += checkpoint['epoch']
1050
+ if 'global_step' in checkpoint:
1051
+ steps += checkpoint['global_step']
1052
+
1053
+ new_ckpt['epoch'] = epochs
1054
+ new_ckpt['global_step'] = steps
1055
+
1056
+ if is_safetensors(output_file):
1057
+ # TODO would it be better to drop non-Tensor values from the dict?
1058
+ save_file(state_dict, output_file)
1059
+ else:
1060
+ torch.save(new_ckpt, output_file)
1061
+
1062
+ return key_count
1063
+
1064
+
1065
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1066
+ if pretrained_model_name_or_path is None:
1067
+ # load default settings for v1/v2
1068
+ if v2:
1069
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1070
+ else:
1071
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1072
+
1073
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1074
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1075
+ if vae is None:
1076
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1077
+
1078
+ pipeline = StableDiffusionPipeline(
1079
+ unet=unet,
1080
+ text_encoder=text_encoder,
1081
+ vae=vae,
1082
+ scheduler=scheduler,
1083
+ tokenizer=tokenizer,
1084
+ safety_checker=None,
1085
+ feature_extractor=None,
1086
+ requires_safety_checker=None,
1087
+ )
1088
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1089
+
1090
+
1091
+ VAE_PREFIX = "first_stage_model."
1092
+
1093
+
1094
+ def load_vae(vae_id, dtype):
1095
+ print(f"load VAE: {vae_id}")
1096
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1097
+ # Diffusers local/remote
1098
+ try:
1099
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1100
+ except EnvironmentError as e:
1101
+ print(f"exception occurs in loading vae: {e}")
1102
+ print("retry with subfolder='vae'")
1103
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1104
+ return vae
1105
+
1106
+ # local
1107
+ vae_config = create_vae_diffusers_config()
1108
+
1109
+ if vae_id.endswith(".bin"):
1110
+ # SD 1.5 VAE on Huggingface
1111
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1112
+ else:
1113
+ # StableDiffusion
1114
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1115
+ else torch.load(vae_id, map_location="cpu"))
1116
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1117
+
1118
+ # vae only or full model
1119
+ full_model = False
1120
+ for vae_key in vae_sd:
1121
+ if vae_key.startswith(VAE_PREFIX):
1122
+ full_model = True
1123
+ break
1124
+ if not full_model:
1125
+ sd = {}
1126
+ for key, value in vae_sd.items():
1127
+ sd[VAE_PREFIX + key] = value
1128
+ vae_sd = sd
1129
+ del sd
1130
+
1131
+ # Convert the VAE model.
1132
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1133
+
1134
+ vae = AutoencoderKL(**vae_config)
1135
+ vae.load_state_dict(converted_vae_checkpoint)
1136
+ return vae
1137
+
1138
+ # endregion
1139
+
1140
+
1141
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1142
+ max_width, max_height = max_reso
1143
+ max_area = (max_width // divisible) * (max_height // divisible)
1144
+
1145
+ resos = set()
1146
+
1147
+ size = int(math.sqrt(max_area)) * divisible
1148
+ resos.add((size, size))
1149
+
1150
+ size = min_size
1151
+ while size <= max_size:
1152
+ width = size
1153
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1154
+ resos.add((width, height))
1155
+ resos.add((height, width))
1156
+
1157
+ # # make additional resos
1158
+ # if width >= height and width - divisible >= min_size:
1159
+ # resos.add((width - divisible, height))
1160
+ # resos.add((height, width - divisible))
1161
+ # if height >= width and height - divisible >= min_size:
1162
+ # resos.add((width, height - divisible))
1163
+ # resos.add((height - divisible, width))
1164
+
1165
+ size += divisible
1166
+
1167
+ resos = list(resos)
1168
+ resos.sort()
1169
+
1170
+ aspect_ratios = [w / h for w, h in resos]
1171
+ return resos, aspect_ratios
1172
+
1173
+
1174
+ if __name__ == '__main__':
1175
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
1176
+ print(len(resos))
1177
+ print(resos)
1178
+ print(aspect_ratios)
1179
+
1180
+ ars = set()
1181
+ for ar in aspect_ratios:
1182
+ if ar in ars:
1183
+ print("error! duplicate ar:", ar)
1184
+ ars.add(ar)
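A usage sketch for the load/save helpers in this file; it assumes the module is importable as library.model_util and that "model.safetensors" is an existing SD 1.x checkpoint, neither of which is part of the commit:

# Illustrative sketch only; "model.safetensors" is a placeholder path.
import torch
from library import model_util

text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="model.safetensors", dtype=torch.float32)

# Write everything back to the single-file SD format, reusing the source
# checkpoint for epoch/step metadata.
key_count = model_util.save_stable_diffusion_checkpoint(
    v2=False, output_file="model_out.safetensors",
    text_encoder=text_encoder, unet=unet, ckpt_path="model.safetensors",
    epochs=0, steps=0, save_dtype=torch.float16, vae=vae)
print(f"saved {key_count} keys")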
locon/kohya_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ # part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
2
+
3
+ import hashlib
4
+ import safetensors
5
+ from io import BytesIO
6
+
7
+
8
+ def addnet_hash_legacy(b):
9
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
10
+ m = hashlib.sha256()
11
+
12
+ b.seek(0x100000)
13
+ m.update(b.read(0x10000))
14
+ return m.hexdigest()[0:8]
15
+
16
+
17
+ def addnet_hash_safetensors(b):
18
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
19
+ hash_sha256 = hashlib.sha256()
20
+ blksize = 1024 * 1024
21
+
22
+ b.seek(0)
23
+ header = b.read(8)
24
+ n = int.from_bytes(header, "little")
25
+
26
+ offset = n + 8
27
+ b.seek(offset)
28
+ for chunk in iter(lambda: b.read(blksize), b""):
29
+ hash_sha256.update(chunk)
30
+
31
+ return hash_sha256.hexdigest()
32
+
33
+
34
+ def precalculate_safetensors_hashes(tensors, metadata):
35
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
36
+ save time on indexing the model later."""
37
+
38
+ # Because writing user metadata to the file can change the result of
39
+ # sd_models.model_hash(), only retain the training metadata for purposes of
40
+ # calculating the hash, as they are meant to be immutable
41
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
42
+
43
+ bytes = safetensors.torch.save(tensors, metadata)
44
+ b = BytesIO(bytes)
45
+
46
+ model_hash = addnet_hash_safetensors(b)
47
+ legacy_hash = addnet_hash_legacy(b)
48
+ return model_hash, legacy_hash
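A minimal sketch of calling these helpers on an in-memory LoRA state dict; the tensors and metadata below are made-up placeholders:

# Illustrative sketch only; tensors and metadata are placeholders.
import torch
from locon.kohya_utils import precalculate_safetensors_hashes

tensors = {"lora_down.weight": torch.zeros(4, 8), "lora_up.weight": torch.zeros(8, 4)}
metadata = {"ss_network_dim": "4", "note": "non-ss_ keys are dropped before hashing"}

model_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
print(model_hash, legacy_hash)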
locon/locon.py ADDED
@@ -0,0 +1,53 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class LoConModule(nn.Module):
9
+ """
10
+ modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
11
+ """
12
+
13
+ def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
14
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
15
+ super().__init__()
16
+ self.lora_name = lora_name
17
+ self.lora_dim = lora_dim
18
+
19
+ if org_module.__class__.__name__ == 'Conv2d':
20
+ # For general LoCon
21
+ in_dim = org_module.in_channels
22
+ k_size = org_module.kernel_size
23
+ stride = org_module.stride
24
+ padding = org_module.padding
25
+ out_dim = org_module.out_channels
26
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
27
+ self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
28
+ else:
29
+ in_dim = org_module.in_features
30
+ out_dim = org_module.out_features
31
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
32
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
33
+
34
+ if type(alpha) == torch.Tensor:
35
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
36
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
37
+ self.scale = alpha / self.lora_dim
38
+ self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
39
+
40
+ # same as microsoft's
41
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
42
+ torch.nn.init.zeros_(self.lora_up.weight)
43
+
44
+ self.multiplier = multiplier
45
+ self.org_module = org_module # remove in applying
46
+
47
+ def apply_to(self):
48
+ self.org_forward = self.org_module.forward
49
+ self.org_module.forward = self.forward
50
+ del self.org_module
51
+
52
+ def forward(self, x):
53
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
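A small sketch (arbitrary placeholder dimensions) showing that a freshly created LoConModule is a no-op, because lora_up is zero-initialized:

# Illustrative sketch only; the layer sizes are placeholders.
import torch
import torch.nn as nn
from locon.locon import LoConModule

linear = nn.Linear(16, 32)
x = torch.randn(2, 16)
before = linear(x)

lora = LoConModule("lora_te_test", linear, multiplier=1.0, lora_dim=4, alpha=1)
lora.apply_to()  # swaps linear.forward for the LoRA-augmented forward

after = linear(x)
print(torch.allclose(before, after))  # True: lora_up starts at zero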
locon/locon_kohya.py ADDED
@@ -0,0 +1,243 @@
1
+ # LoCon network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
+
7
+ import math
8
+ import os
9
+ from typing import List
10
+ import torch
11
+
12
+ from .kohya_utils import *
13
+ from .locon import LoConModule
14
+
15
+
16
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
17
+ if network_dim is None:
18
+ network_dim = 4 # default
19
+ conv_dim = kwargs.get('conv_dim', network_dim)
20
+ conv_alpha = kwargs.get('conv_alpha', network_alpha)
21
+ network = LoRANetwork(
22
+ text_encoder, unet,
23
+ multiplier=multiplier,
24
+ lora_dim=network_dim, conv_lora_dim=conv_dim,
25
+ alpha=network_alpha, conv_alpha=conv_alpha
26
+ )
27
+ return network
28
+
29
+
30
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
31
+ if os.path.splitext(file)[1] == '.safetensors':
32
+ from safetensors.torch import load_file, safe_open
33
+ weights_sd = load_file(file)
34
+ else:
35
+ weights_sd = torch.load(file, map_location='cpu')
36
+
37
+ # get dim (rank)
38
+ network_alpha = None
39
+ network_dim = None
40
+ for key, value in weights_sd.items():
41
+ if network_alpha is None and 'alpha' in key:
42
+ network_alpha = value
43
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
44
+ network_dim = value.size()[0]
45
+
46
+ if network_alpha is None:
47
+ network_alpha = network_dim
48
+
49
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
50
+ network.weights_sd = weights_sd
51
+ return network
52
+
53
+ torch.nn.Conv2d
54
+ class LoRANetwork(torch.nn.Module):
55
+ '''
56
+ LoRA + LoCon
57
+ '''
58
+ # Ignore proj_in and proj_out; they only have a few channels.
59
+ UNET_TARGET_REPLACE_MODULE = [
60
+ "Transformer2DModel",
61
+ "Attention",
62
+ "ResnetBlock2D",
63
+ "Downsample2D",
64
+ "Upsample2D"
65
+ ]
66
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
67
+ LORA_PREFIX_UNET = 'lora_unet'
68
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
69
+
70
+ def __init__(
71
+ self,
72
+ text_encoder, unet,
73
+ multiplier=1.0,
74
+ lora_dim=4, conv_lora_dim=4,
75
+ alpha=1, conv_alpha=1
76
+ ) -> None:
77
+ super().__init__()
78
+ self.multiplier = multiplier
79
+ self.lora_dim = lora_dim
80
+ self.conv_lora_dim = int(conv_lora_dim)
81
+ if self.conv_lora_dim != self.lora_dim:
82
+ print('Apply different lora dim for conv layer')
83
+ print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
84
+ self.alpha = alpha
85
+ self.conv_alpha = float(conv_alpha)
86
+ if self.alpha != self.conv_alpha:
87
+ print('Apply different alpha value for conv layer')
88
+ print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
89
+
90
+ # create module instances
91
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoConModule]:
92
+ print('Create LoCon Module')
93
+ loras = []
94
+ for name, module in root_module.named_modules():
95
+ if module.__class__.__name__ in target_replace_modules:
96
+ for child_name, child_module in module.named_modules():
97
+ lora_name = prefix + '.' + name + '.' + child_name
98
+ lora_name = lora_name.replace('.', '_')
99
+ if child_module.__class__.__name__ == 'Linear':
100
+ lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
101
+ elif child_module.__class__.__name__ == 'Conv2d':
102
+ k_size, *_ = child_module.kernel_size
103
+ if k_size==1:
104
+ lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
105
+ else:
106
+ lora = LoConModule(lora_name, child_module, self.multiplier, self.conv_lora_dim, self.conv_alpha)
107
+ else:
108
+ continue
109
+ loras.append(lora)
110
+ return loras
111
+
112
+ self.text_encoder_loras = create_modules(
113
+ LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
114
+ text_encoder,
115
+ LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
116
+ )
117
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
118
+
119
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
120
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
121
+
122
+ self.weights_sd = None
123
+
124
+ # assertion
125
+ names = set()
126
+ for lora in self.text_encoder_loras + self.unet_loras:
127
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
128
+ names.add(lora.lora_name)
129
+
130
+ def set_multiplier(self, multiplier):
131
+ self.multiplier = multiplier
132
+ for lora in self.text_encoder_loras + self.unet_loras:
133
+ lora.multiplier = self.multiplier
134
+
135
+ def load_weights(self, file):
136
+ if os.path.splitext(file)[1] == '.safetensors':
137
+ from safetensors.torch import load_file, safe_open
138
+ self.weights_sd = load_file(file)
139
+ else:
140
+ self.weights_sd = torch.load(file, map_location='cpu')
141
+
142
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
143
+ if self.weights_sd:
144
+ weights_has_text_encoder = weights_has_unet = False
145
+ for key in self.weights_sd.keys():
146
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
147
+ weights_has_text_encoder = True
148
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
149
+ weights_has_unet = True
150
+
151
+ if apply_text_encoder is None:
152
+ apply_text_encoder = weights_has_text_encoder
153
+ else:
154
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the weights and the text encoder flag are inconsistent"
155
+
156
+ if apply_unet is None:
157
+ apply_unet = weights_has_unet
158
+ else:
159
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the weights and the U-Net flag are inconsistent"
160
+ else:
161
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
162
+
163
+ if apply_text_encoder:
164
+ print("enable LoRA for text encoder")
165
+ else:
166
+ self.text_encoder_loras = []
167
+
168
+ if apply_unet:
169
+ print("enable LoRA for U-Net")
170
+ else:
171
+ self.unet_loras = []
172
+
173
+ for lora in self.text_encoder_loras + self.unet_loras:
174
+ lora.apply_to()
175
+ self.add_module(lora.lora_name, lora)
176
+
177
+ if self.weights_sd:
178
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
179
+ info = self.load_state_dict(self.weights_sd, False)
180
+ print(f"weights are loaded: {info}")
181
+
182
+ def enable_gradient_checkpointing(self):
183
+ # not supported
184
+ pass
185
+
186
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
187
+ def enumerate_params(loras):
188
+ params = []
189
+ for lora in loras:
190
+ params.extend(lora.parameters())
191
+ return params
192
+
193
+ self.requires_grad_(True)
194
+ all_params = []
195
+
196
+ if self.text_encoder_loras:
197
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
198
+ if text_encoder_lr is not None:
199
+ param_data['lr'] = text_encoder_lr
200
+ all_params.append(param_data)
201
+
202
+ if self.unet_loras:
203
+ param_data = {'params': enumerate_params(self.unet_loras)}
204
+ if unet_lr is not None:
205
+ param_data['lr'] = unet_lr
206
+ all_params.append(param_data)
207
+
208
+ return all_params
209
+
210
+ def prepare_grad_etc(self, text_encoder, unet):
211
+ self.requires_grad_(True)
212
+
213
+ def on_epoch_start(self, text_encoder, unet):
214
+ self.train()
215
+
216
+ def get_trainable_params(self):
217
+ return self.parameters()
218
+
219
+ def save_weights(self, file, dtype, metadata):
220
+ if metadata is not None and len(metadata) == 0:
221
+ metadata = None
222
+
223
+ state_dict = self.state_dict()
224
+
225
+ if dtype is not None:
226
+ for key in list(state_dict.keys()):
227
+ v = state_dict[key]
228
+ v = v.detach().clone().to("cpu").to(dtype)
229
+ state_dict[key] = v
230
+
231
+ if os.path.splitext(file)[1] == '.safetensors':
232
+ from safetensors.torch import save_file
233
+
234
+ # Precalculate model hashes to save time on indexing
235
+ if metadata is None:
236
+ metadata = {}
237
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
238
+ metadata["sshs_model_hash"] = model_hash
239
+ metadata["sshs_legacy_hash"] = legacy_hash
240
+
241
+ save_file(state_dict, file, metadata)
242
+ else:
243
+ torch.save(state_dict, file)
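A rough sketch of the create_network entry point; the bare nn.Module placeholders stand in for the real CLIP text encoder and Diffusers U-Net, so no LoCon modules are actually created here, but the dim/alpha kwargs flow through the same way:

# Illustrative sketch only; the placeholder modules contain no target layers.
import torch.nn as nn
from locon.locon_kohya import create_network

dummy_text_encoder = nn.Module()  # real callers pass a CLIPTextModel
dummy_unet = nn.Module()          # real callers pass a UNet2DConditionModel

network = create_network(
    multiplier=1.0, network_dim=8, network_alpha=4.0,
    vae=None, text_encoder=dummy_text_encoder, unet=dummy_unet,
    conv_dim=4, conv_alpha=1.0)
print(len(network.text_encoder_loras), len(network.unet_loras))  # 0 0 with the placeholders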
locon/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import torch.linalg as linalg
6
+
7
+ from tqdm import tqdm
8
+
9
+
10
+ def extract_conv(
11
+ weight: nn.Parameter|torch.Tensor,
12
+ lora_rank = 8
13
+ ) -> tuple[nn.Parameter, nn.Parameter]:
14
+ out_ch, in_ch, kernel_size, _ = weight.shape
15
+ lora_rank = min(out_ch, in_ch, lora_rank)
16
+
17
+ U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
18
+
19
+ U = U[:, :lora_rank]
20
+ S = S[:lora_rank]
21
+ U = U @ torch.diag(S)
22
+ Vh = Vh[:lora_rank, :]
23
+
24
+ extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
25
+ extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
26
+ del U, S, Vh, weight
27
+ return extract_weight_A, extract_weight_B
28
+
29
+
30
+ def merge_conv(
31
+ weight_a: nn.Parameter|torch.Tensor,
32
+ weight_b: nn.Parameter|torch.Tensor,
33
+ ):
34
+ rank, in_ch, kernel_size, k_ = weight_a.shape
35
+ out_ch, rank_, _, _ = weight_b.shape
36
+
37
+ assert rank == rank_ and kernel_size == k_
38
+
39
+ merged = weight_b.reshape(out_ch, -1) @ weight_a.reshape(rank, -1)
40
+ weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
41
+ return weight
42
+
43
+
44
+ def extract_linear(
45
+ weight: nn.Parameter|torch.Tensor,
46
+ lora_rank = 8
47
+ ) -> tuple[nn.Parameter, nn.Parameter]:
48
+ out_ch, in_ch = weight.shape
49
+ lora_rank = min(out_ch, in_ch, lora_rank)
50
+
51
+ U, S, Vh = linalg.svd(weight)
52
+
53
+ U = U[:, :lora_rank]
54
+ S = S[:lora_rank]
55
+ U = U @ torch.diag(S)
56
+ Vh = Vh[:lora_rank, :]
57
+
58
+ extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
59
+ extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
60
+ del U, S, Vh, weight
61
+ return extract_weight_A, extract_weight_B
62
+
63
+
64
+ def merge_linear(
65
+ weight_a: nn.Parameter|torch.Tensor,
66
+ weight_b: nn.Parameter|torch.Tensor,
67
+ ):
68
+ rank, in_ch = weight_a.shape
69
+ out_ch, rank_ = weight_b.shape
70
+
71
+ assert rank == rank_
72
+
73
+ weight = weight_b @ weight_a
74
+ return weight
75
+
76
+
77
+ def extract_diff(
78
+ base_model,
79
+ db_model,
80
+ lora_dim=4,
81
+ conv_lora_dim=4,
82
+ extract_device = 'cuda',
83
+ ):
84
+ UNET_TARGET_REPLACE_MODULE = [
85
+ "Transformer2DModel",
86
+ "Attention",
87
+ "ResnetBlock2D",
88
+ "Downsample2D",
89
+ "Upsample2D"
90
+ ]
91
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
+ LORA_PREFIX_UNET = 'lora_unet'
93
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
+ def make_state_dict(
95
+ prefix,
96
+ root_module: torch.nn.Module,
97
+ target_module: torch.nn.Module,
98
+ target_replace_modules
99
+ ):
100
+ loras = {}
101
+ temp = {}
102
+
103
+ for name, module in root_module.named_modules():
104
+ if module.__class__.__name__ in target_replace_modules:
105
+ temp[name] = {}
106
+ for child_name, child_module in module.named_modules():
107
+ if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
108
+ continue
109
+ temp[name][child_name] = child_module.weight
110
+
111
+ for name, module in tqdm(list(target_module.named_modules())):
112
+ if name in temp:
113
+ weights = temp[name]
114
+ for child_name, child_module in module.named_modules():
115
+ lora_name = prefix + '.' + name + '.' + child_name
116
+ lora_name = lora_name.replace('.', '_')
117
+ if child_module.__class__.__name__ == 'Linear':
118
+ extract_a, extract_b = extract_linear(
119
+ (child_module.weight - weights[child_name]),
120
+ lora_dim
121
+ )
122
+ elif child_module.__class__.__name__ == 'Conv2d':
123
+ extract_a, extract_b = extract_conv(
124
+ (child_module.weight - weights[child_name]),
125
+ conv_lora_dim
126
+ )
127
+ else:
128
+ continue
129
+
130
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().half()
131
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().half()
132
+ loras[f'{lora_name}.alpha'] = torch.Tensor([int(extract_a.shape[0])]).detach().cpu().half()
133
+ del extract_a, extract_b
134
+ return loras
135
+
136
+ text_encoder_loras = make_state_dict(
137
+ LORA_PREFIX_TEXT_ENCODER,
138
+ base_model[0], db_model[0],
139
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
140
+ )
141
+
142
+ unet_loras = make_state_dict(
143
+ LORA_PREFIX_UNET,
144
+ base_model[2], db_model[2],
145
+ UNET_TARGET_REPLACE_MODULE
146
+ )
147
+ print(len(text_encoder_loras), len(unet_loras))
148
+ return text_encoder_loras|unet_loras
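A quick sketch of the SVD-based extraction above (random matrix, placeholder sizes): a truncated rank gives a low-rank approximation, while a rank equal to the smaller dimension reconstructs the weight up to floating-point error:

# Illustrative sketch only; shapes and ranks are placeholders.
import torch
from locon.utils import extract_linear, merge_linear

weight = torch.randn(32, 64)                    # (out_ch, in_ch)
down, up = extract_linear(weight, lora_rank=8)  # down: (8, 64), up: (32, 8)
approx = merge_linear(down, up)
print(approx.shape, (weight - approx).norm() / weight.norm())  # relative error of the rank-8 fit

down, up = extract_linear(weight, lora_rank=32)  # rank == out_ch
print(torch.allclose(weight, merge_linear(down, up), atol=1e-4))  # True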
lora_train_popup.py ADDED
@@ -0,0 +1,862 @@
1
+ import gc
2
+ import json
3
+ import time
4
+ from functools import partial
5
+ from typing import Union
6
+ import os
7
+ import tkinter as tk
8
+ from tkinter import filedialog as fd, ttk
9
+ from tkinter import simpledialog as sd
10
+ from tkinter import messagebox as mb
11
+
12
+ import torch.cuda
13
+ import train_network
14
+ import library.train_util as util
15
+ import argparse
16
+
17
+
18
+ class ArgStore:
19
+ # Represents the entirety of all possible inputs for sd-scripts. they are ordered from most important to least
20
+ def __init__(self):
21
+ # Important, these are the most likely things you will modify
22
+ self.base_model: str = r"" # example path, r"E:\sd\stable-diffusion-webui\models\Stable-diffusion\nai.ckpt"
23
+ self.img_folder: str = r"" # is the folder path to your img folder, make sure to follow the guide here for folder setup: https://rentry.org/2chAI_LoRA_Dreambooth_guide_english#for-kohyas-script
24
+ self.output_folder: str = r"" # just the folder all epochs/safetensors are output
25
+ self.change_output_name: Union[str, None] = None # changes the output name of the epochs
26
+ self.save_json_folder: Union[str, None] = None # OPTIONAL, saves a json folder of your config to whatever location you set here.
27
+ self.load_json_path: Union[str, None] = None # OPTIONAL, loads a json file and partially changes the config to match it. Things like folder paths do not get modified.
28
+ self.json_load_skip_list: Union[list[str], None] = ["save_json_folder", "reg_img_folder",
29
+ "lora_model_for_resume", "change_output_name",
30
+ "training_comment",
31
+ "json_load_skip_list"] # OPTIONAL, allows the user to define what they skip when loading a json, by default it loads everything, including all paths, set it up like this ["base_model", "img_folder", "output_folder"]
32
+ self.caption_dropout_rate: Union[float, None] = None # The rate at which captions for files get dropped.
33
+ self.caption_dropout_every_n_epochs: Union[int, None] = None # Defines how often an epoch will completely ignore
34
+ # captions, EX. 3 means it will ignore captions at epochs 3, 6, and 9
35
+ self.caption_tag_dropout_rate: Union[float, None] = None # Defines the rate at which a tag would be dropped, rather than the entire caption file
36
+ self.noise_offset: Union[float, None] = None # OPTIONAL, seems to help SD generate better blacks and whites
37
+ # Kohya recommends, if you have it set, to use 0.1, not sure how
38
+ # high the value can be, I'm going to assume maximum of 1
39
+
40
+ self.net_dim: int = 128 # network dimension, 128 is the most common, however you might be able to get lesser to work
41
+ self.alpha: float = 128 # represents the scalar for training. the lower the alpha, the less gets learned per step. if you want the older way of training, set this to dim
42
+ # list of schedulers: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
43
+ self.scheduler: str = "cosine_with_restarts" # the scheduler for learning rate. Each does something specific
44
+ self.cosine_restarts: Union[int, None] = 1 # OPTIONAL, represents the number of times it restarts. Only matters if you are using cosine_with_restarts
45
+ self.scheduler_power: Union[float, None] = 1 # OPTIONAL, represents the power of the polynomial. Only matters if you are using polynomial
46
+ self.warmup_lr_ratio: Union[float, None] = None # OPTIONAL, Calculates the number of warmup steps based on the ratio given. Make sure to set this if you are using constant_with_warmup, None to ignore
47
+ self.learning_rate: Union[float, None] = 1e-4 # OPTIONAL, when not set, lr gets set to 1e-3 as per adamW. Personally, I suggest actually setting this as lower lr seems to be a small bit better.
48
+ self.text_encoder_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the text encoder, this overwrites the base lr I believe, None to ignore
49
+ self.unet_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the unet, this overwrites the base lr I believe, None to ignore
50
+ self.num_workers: int = 1 # The number of worker threads used to load images; fewer workers speeds up the start of each epoch but slows data loading, so reducing this value is assumed to increase overall training time
51
+ self.persistent_workers: bool = True # makes workers persistent, further reduces/eliminates the lag in between epochs. however it may increase memory usage
52
+
53
+ self.batch_size: int = 1 # The number of images that get processed at one time, this is directly proportional to your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 batch size
54
+ self.num_epochs: int = 1 # The number of epochs, if you set max steps this value is ignored as it doesn't calculate steps.
55
+ self.save_every_n_epochs: Union[int, None] = None # OPTIONAL, how often to save epochs, None to ignore
56
+ self.shuffle_captions: bool = False # OPTIONAL, False to ignore
57
+ self.keep_tokens: Union[int, None] = None # OPTIONAL, None to ignore
58
+ self.max_steps: Union[int, None] = None # OPTIONAL, if you have specific steps you want to hit, this allows you to set it directly. None to ignore
59
+ self.tag_occurrence_txt_file: bool = False # OPTIONAL, creates a txt file listing the occurrence count of every tag in your dataset
60
+ # the metadata will also have this so long as you have metadata on, so no reason to have this on by default
61
+ # will automatically output to the same folder as your output checkpoints
62
+ self.sort_tag_occurrence_alphabetically: bool = False # OPTIONAL, only applies if tag_occurrence_txt_file is also true
63
+ # Will change the output to be alphabetically vs being occurrence based
64
+
65
+ # These are the second most likely things you will modify
66
+ self.train_resolution: int = 512
67
+ self.min_bucket_resolution: int = 320
68
+ self.max_bucket_resolution: int = 960
69
+ self.lora_model_for_resume: Union[str, None] = None # OPTIONAL, takes an input lora to continue training from, not exactly the way it *should* be, but it works, None to ignore
70
+ self.save_state: bool = False # OPTIONAL, is the intended way to save a training state to use for continuing training, False to ignore
71
+ self.load_previous_save_state: Union[str, None] = None # OPTIONAL, is the intended way to load a training state to use for continuing training, None to ignore
72
+ self.training_comment: Union[str, None] = None # OPTIONAL, a great way to put things like activation tokens right into the metadata. Seems not to work at this point in time
73
+ self.unet_only: bool = False # OPTIONAL, set it to only train the unet
74
+ self.text_only: bool = False # OPTIONAL, set it to only train the text encoder
75
+
76
+ # These are the least likely things you will modify
77
+ self.reg_img_folder: Union[str, None] = None # OPTIONAL, None to ignore
78
+ self.clip_skip: int = 2 # If you are training on a model that is anime based, keep this at 2 as most models are designed for that
79
+ self.test_seed: int = 23 # this is the "reproducible seed"; if you set the seed to this, you should be able to input a prompt from one of your training images and get a close representation of it
80
+ self.prior_loss_weight: float = 1 # the loss weight, much like Dreambooth; it is required for LoRA training
81
+ self.gradient_checkpointing: bool = False # OPTIONAL, enables gradient checkpointing
82
+ self.gradient_acc_steps: Union[int, None] = None # OPTIONAL, the number of steps to accumulate gradients over before an optimizer update (effectively simulates a larger batch size)
83
+ self.mixed_precision: str = "fp16" # If you have the ability to use bf16, do it, it's better
84
+ self.save_precision: str = "fp16" # You can also save in bf16, but because it's not universally supported, I suggest you keep saving at fp16
85
+ self.save_as: str = "safetensors" # list is pt, ckpt, safetensors
86
+ self.caption_extension: str = ".txt" # the other option is .captions, but since wd1.4 tagger outputs as txt files, this is the default
87
+ self.max_clip_token_length = 150 # can be 75, 150, or 225 I believe, there is no reason to go higher than 150 though
88
+ self.buckets: bool = True
89
+ self.xformers: bool = True
90
+ self.use_8bit_adam: bool = True
91
+ self.cache_latents: bool = True
92
+ self.color_aug: bool = False # IMPORTANT: Clashes with cache_latents, only have one of the two on!
93
+ self.flip_aug: bool = False
94
+ self.vae: Union[str, None] = None # Seems to only make results worse unless the model calls for that specific vae; probably leave unset
95
+ self.no_meta: bool = False # Setting this removes the metadata that now gets saved into safetensors (you should probably keep the metadata)
96
+ self.log_dir: Union[str, None] = None # output of logs, not useful to most people.
97
+ self.v2: bool = False # Sets up training for SD2.1
98
+ self.v_parameterization: bool = False # Only used when v2 is also set and you are using the 768x version of v2
99
+
100
+ # Creates the dict that is used for the rest of the code, to facilitate easier json saving and loading
101
+ @staticmethod
102
+ def convert_args_to_dict():
103
+ return ArgStore().__dict__
104
+
105
+
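A note on the net_dim / alpha pair configured above: in kohya-style LoRA training the learned update is scaled by alpha / dim, so alpha equal to dim reproduces the older unscaled behaviour. A minimal standalone sketch (illustrative values only, not part of the uploaded script):

# Illustrative only: how alpha scales the LoRA update (scale = alpha / dim).
def lora_scale(alpha: float, dim: int) -> float:
    return alpha / dim

print(lora_scale(128, 128))  # 1.0 -> behaves like the old, unscaled training
print(lora_scale(64, 128))   # 0.5 -> each step effectively learns half as much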
106
+ def main():
107
+ parser = argparse.ArgumentParser()
108
+ setup_args(parser)
109
+ pre_args = parser.parse_args()
110
+ queues = 0
111
+ args_queue = []
112
+ cont = True
113
+ while cont:
114
+ arg_dict = ArgStore.convert_args_to_dict()
115
+ ret = mb.askyesno(message="Do you want to load a json config file?")
116
+ if ret:
117
+ load_json(ask_file("select json to load from", {"json"}), arg_dict)
118
+ arg_dict = ask_elements_trunc(arg_dict)
119
+ else:
120
+ arg_dict = ask_elements(arg_dict)
121
+ if pre_args.save_json_path or arg_dict["save_json_folder"]:
122
+ save_json(pre_args.save_json_path if pre_args.save_json_path else arg_dict['save_json_folder'], arg_dict)
123
+ args = create_arg_space(arg_dict)
124
+ args = parser.parse_args(args)
125
+ queues += 1
126
+ args_queue.append(args)
127
+ if arg_dict['tag_occurrence_txt_file']:
128
+ get_occurrence_of_tags(arg_dict)
129
+ ret = mb.askyesno(message="Do you want to queue another training?")
130
+ if not ret:
131
+ cont = False
132
+ for args in args_queue:
133
+ try:
134
+ train_network.train(args)
135
+ except Exception as e:
136
+ print(f"Failed to train this set of args.\nSkipping this training session.\nError is: {e}")
137
+ gc.collect()
138
+ torch.cuda.empty_cache()
139
+
140
+
141
+ def create_arg_space(args: dict) -> [str]:
142
+ # This is the list of args that are to be used regardless of setup
143
+ output = ["--network_module=networks.lora", f"--pretrained_model_name_or_path={args['base_model']}",
144
+ f"--train_data_dir={args['img_folder']}", f"--output_dir={args['output_folder']}",
145
+ f"--prior_loss_weight={args['prior_loss_weight']}", f"--caption_extension=" + args['caption_extension'],
146
+ f"--resolution={args['train_resolution']}", f"--train_batch_size={args['batch_size']}",
147
+ f"--mixed_precision={args['mixed_precision']}", f"--save_precision={args['save_precision']}",
148
+ f"--network_dim={args['net_dim']}", f"--save_model_as={args['save_as']}",
149
+ f"--clip_skip={args['clip_skip']}", f"--seed={args['test_seed']}",
150
+ f"--max_token_length={args['max_clip_token_length']}", f"--lr_scheduler={args['scheduler']}",
151
+ f"--network_alpha={args['alpha']}", f"--max_data_loader_n_workers={args['num_workers']}"]
152
+ if not args['max_steps']:
153
+ output.append(f"--max_train_epochs={args['num_epochs']}")
154
+ output += create_optional_args(args, find_max_steps(args))
155
+ else:
156
+ output.append(f"--max_train_steps={args['max_steps']}")
157
+ output += create_optional_args(args, args['max_steps'])
158
+ return output
159
+
160
+
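create_arg_space builds plain "--key=value" strings and main() feeds them back through parser.parse_args, so the whole GUI reduces to assembling a CLI argument list. A small standalone illustration of that pattern (hypothetical minimal parser, not the real sd-scripts one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", type=int)
parser.add_argument("--network_dim", type=int)

flags = ["--train_batch_size=2", "--network_dim=128"]
print(parser.parse_args(flags))  # Namespace(network_dim=128, train_batch_size=2)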
161
+ def create_optional_args(args: dict, steps):
162
+ output = []
163
+ if args["reg_img_folder"]:
164
+ output.append(f"--reg_data_dir={args['reg_img_folder']}")
165
+
166
+ if args['lora_model_for_resume']:
167
+ output.append(f"--network_weights={args['lora_model_for_resume']}")
168
+
169
+ if args['save_every_n_epochs']:
170
+ output.append(f"--save_every_n_epochs={args['save_every_n_epochs']}")
171
+ else:
172
+ output.append("--save_every_n_epochs=999999")
173
+
174
+ if args['shuffle_captions']:
175
+ output.append("--shuffle_caption")
176
+
177
+ if args['keep_tokens'] and args['keep_tokens'] > 0:
178
+ output.append(f"--keep_tokens={args['keep_tokens']}")
179
+
180
+ if args['buckets']:
181
+ output.append("--enable_bucket")
182
+ output.append(f"--min_bucket_reso={args['min_bucket_resolution']}")
183
+ output.append(f"--max_bucket_reso={args['max_bucket_resolution']}")
184
+
185
+ if args['use_8bit_adam']:
186
+ output.append("--use_8bit_adam")
187
+
188
+ if args['xformers']:
189
+ output.append("--xformers")
190
+
191
+ if args['color_aug']:
192
+ if args['cache_latents']:
193
+ print("color_aug and cache_latents conflict with one another. Please select only one")
194
+ quit(1)
195
+ output.append("--color_aug")
196
+
197
+ if args['flip_aug']:
198
+ output.append("--flip_aug")
199
+
200
+ if args['cache_latents']:
201
+ output.append("--cache_latents")
202
+
203
+ if args['warmup_lr_ratio'] and args['warmup_lr_ratio'] > 0:
204
+ warmup_steps = int(steps * args['warmup_lr_ratio'])
205
+ output.append(f"--lr_warmup_steps={warmup_steps}")
206
+
207
+ if args['gradient_checkpointing']:
208
+ output.append("--gradient_checkpointing")
209
+
210
+ if args['gradient_acc_steps'] and args['gradient_acc_steps'] > 0 and args['gradient_checkpointing']:
211
+ output.append(f"--gradient_accumulation_steps={args['gradient_acc_steps']}")
212
+
213
+ if args['learning_rate'] and args['learning_rate'] > 0:
214
+ output.append(f"--learning_rate={args['learning_rate']}")
215
+
216
+ if args['text_encoder_lr'] and args['text_encoder_lr'] > 0:
217
+ output.append(f"--text_encoder_lr={args['text_encoder_lr']}")
218
+
219
+ if args['unet_lr'] and args['unet_lr'] > 0:
220
+ output.append(f"--unet_lr={args['unet_lr']}")
221
+
222
+ if args['vae']:
223
+ output.append(f"--vae={args['vae']}")
224
+
225
+ if args['no_meta']:
226
+ output.append("--no_metadata")
227
+
228
+ if args['save_state']:
229
+ output.append("--save_state")
230
+
231
+ if args['load_previous_save_state']:
232
+ output.append(f"--resume={args['load_previous_save_state']}")
233
+
234
+ if args['change_output_name']:
235
+ output.append(f"--output_name={args['change_output_name']}")
236
+
237
+ if args['training_comment']:
238
+ output.append(f"--training_comment={args['training_comment']}")
239
+
240
+ if args['cosine_restarts'] and args['scheduler'] == "cosine_with_restarts":
241
+ output.append(f"--lr_scheduler_num_cycles={args['cosine_restarts']}")
242
+
243
+ if args['scheduler_power'] and args['scheduler'] == "polynomial":
244
+ output.append(f"--lr_scheduler_power={args['scheduler_power']}")
245
+
246
+ if args['persistent_workers']:
247
+ output.append(f"--persistent_data_loader_workers")
248
+
249
+ if args['unet_only']:
250
+ output.append("--network_train_unet_only")
251
+
252
+ if args['text_only'] and not args['unet_only']:
253
+ output.append("--network_train_text_encoder_only")
254
+
255
+ if args["log_dir"]:
256
+ output.append(f"--logging_dir={args['log_dir']}")
257
+
258
+ if args['caption_dropout_rate']:
259
+ output.append(f"--caption_dropout_rate={args['caption_dropout_rate']}")
260
+
261
+ if args['caption_dropout_every_n_epochs']:
262
+ output.append(f"--caption_dropout_every_n_epochs={args['caption_dropout_every_n_epochs']}")
263
+
264
+ if args['caption_tag_dropout_rate']:
265
+ output.append(f"--caption_tag_dropout_rate={args['caption_tag_dropout_rate']}")
266
+
267
+ if args['v2']:
268
+ output.append("--v2")
269
+
270
+ if args['v2'] and args['v_parameterization']:
271
+ output.append("--v_parameterization")
272
+
273
+ if args['noise_offset']:
274
+ output.append(f"--noise_offset={args['noise_offset']}")
275
+ return output
276
+
277
+
278
+ def find_max_steps(args: dict) -> int:
279
+ total_steps = 0
280
+ folders = os.listdir(args["img_folder"])
281
+ for folder in folders:
282
+ if not os.path.isdir(os.path.join(args["img_folder"], folder)):
283
+ continue
284
+ num_repeats = folder.split("_")
285
+ if len(num_repeats) < 2:
286
+ print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
287
+ continue
288
+ try:
289
+ num_repeats = int(num_repeats[0])
290
+ except ValueError:
291
+ print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
292
+ continue
293
+ imgs = 0
294
+ for file in os.listdir(os.path.join(args["img_folder"], folder)):
295
+ if os.path.isdir(os.path.join(args["img_folder"], folder, file)):
296
+ continue
297
+ ext = file.split(".")
298
+ if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
299
+ imgs += 1
300
+ total_steps += (num_repeats * imgs)
301
+ total_steps = int((total_steps / args["batch_size"]) * args["num_epochs"])
302
+ return total_steps
303
+
304
+
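A worked example of the arithmetic in find_max_steps, and of the warmup step count that create_optional_args derives from it (hypothetical dataset numbers, assuming the x_name folder convention checked above):

# Hypothetical dataset: one folder named "5_character" containing 100 images.
repeats, images = 5, 100
batch_size, num_epochs = 2, 10

total_steps = int((repeats * images / batch_size) * num_epochs)
print(total_steps)  # 2500

warmup_lr_ratio = 0.05  # would be passed as --lr_warmup_steps=int(total_steps * ratio)
print(int(total_steps * warmup_lr_ratio))  # 125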
305
+ def add_misc_args(parser):
306
+ parser.add_argument("--save_json_path", type=str, default=None,
307
+ help="Path to save a configuration json file to")
308
+ parser.add_argument("--load_json_path", type=str, default=None,
309
+ help="Path to a json file to configure things from")
310
+ parser.add_argument("--no_metadata", action='store_true',
311
+ help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
312
+ parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
313
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
314
+
315
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
316
+ parser.add_argument("--text_encoder_lr", type=float, default=None,
317
+ help="learning rate for Text Encoder / Text Encoderの学習率")
318
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
319
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
320
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
321
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
322
+
323
+ parser.add_argument("--network_weights", type=str, default=None,
324
+ help="pretrained weights for network / 学習するネットワークの初期重み")
325
+ parser.add_argument("--network_module", type=str, default=None,
326
+ help='network module to train / 学習対象のネットワークのモジュール')
327
+ parser.add_argument("--network_dim", type=int, default=None,
328
+ help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
329
+ parser.add_argument("--network_alpha", type=float, default=1,
330
+ help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
331
+ parser.add_argument("--network_args", type=str, default=None, nargs='*',
332
+ help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
333
+ parser.add_argument("--network_train_unet_only", action="store_true",
334
+ help="only training U-Net part / U-Net関連部分のみ学習する")
335
+ parser.add_argument("--network_train_text_encoder_only", action="store_true",
336
+ help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
337
+ parser.add_argument("--training_comment", type=str, default=None,
338
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
339
+
340
+
341
+ def setup_args(parser):
342
+ util.add_sd_models_arguments(parser)
343
+ util.add_dataset_arguments(parser, True, True, True)
344
+ util.add_training_arguments(parser, True)
345
+ add_misc_args(parser)
346
+
347
+
348
+ def get_occurrence_of_tags(args):
349
+ extension = args['caption_extension']
350
+ img_folder = args['img_folder']
351
+ output_folder = args['output_folder']
352
+ occurrence_dict = {}
353
+ print(img_folder)
354
+ for folder in os.listdir(img_folder):
355
+ print(folder)
356
+ if not os.path.isdir(os.path.join(img_folder, folder)):
357
+ continue
358
+ for file in os.listdir(os.path.join(img_folder, folder)):
359
+ if not os.path.isfile(os.path.join(img_folder, folder, file)):
360
+ continue
361
+ ext = os.path.splitext(file)[1]
362
+ if ext != extension:
363
+ continue
364
+ get_tags_from_file(os.path.join(img_folder, folder, file), occurrence_dict)
365
+ if not args['sort_tag_occurrence_alphabetically']:
366
+ output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[1], reverse=True)}
367
+ else:
368
+ output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[0])}
369
+ name = args['change_output_name'] if args['change_output_name'] else "last"
370
+ with open(os.path.join(output_folder, f"{name}.txt"), "w") as f:
371
+ f.write(f"Below is a list of keywords used during the training of {args['change_output_name']}:\n")
372
+ for k, v in output_list.items():
373
+ f.write(f"[{v}] {k}\n")
374
+ print(f"Created a txt file named {name}.txt in the output folder")
375
+
376
+
377
+ def get_tags_from_file(file, occurrence_dict):
378
+ f = open(file)
379
+ temp = f.read().replace(", ", ",").split(",")
380
+ f.close()
381
+ for tag in temp:
382
+ if tag in occurrence_dict:
383
+ occurrence_dict[tag] += 1
384
+ else:
385
+ occurrence_dict[tag] = 1
386
+
387
+
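The two helpers above accumulate tag counts by hand; the same counting and both sort orders (occurrence vs. alphabetical, mirroring sort_tag_occurrence_alphabetically) can be sketched with collections.Counter. Illustrative only, with hypothetical caption strings instead of real files:

from collections import Counter

captions = ["1girl, solo, long hair", "1girl, smile, long hair"]
counter = Counter(tag for line in captions
                  for tag in line.replace(", ", ",").split(","))

print(counter.most_common())    # [('1girl', 2), ('long hair', 2), ('solo', 1), ('smile', 1)]
print(sorted(counter.items()))  # [('1girl', 2), ('long hair', 2), ('smile', 1), ('solo', 1)]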
388
+ def ask_file(message, accepted_ext_list, file_path=None):
389
+ mb.showinfo(message=message)
390
+ res = ""
391
+ _initialdir = ""
392
+ _initialfile = ""
393
+ if file_path != None:
394
+ _initialdir = os.path.dirname(file_path) if os.path.exists(file_path) else ""
395
+ _initialfile = os.path.basename(file_path) if os.path.exists(file_path) else ""
396
+
397
+ while res == "":
398
+ res = fd.askopenfilename(title=message, initialdir=_initialdir, initialfile=_initialfile)
399
+ if res == "" or type(res) == tuple:
400
+ ret = mb.askretrycancel(message="Do you want to cancel training?")
401
+ if not ret:
402
+ exit()
403
+ continue
404
+ elif not os.path.exists(res):
405
+ res = ""
406
+ continue
407
+ _, name = os.path.split(res)
408
+ split_name = name.split(".")
409
+ if split_name[-1] not in accepted_ext_list:
410
+ res = ""
411
+ return res
412
+
413
+
414
+ def ask_dir(message, dir_path=None):
415
+ mb.showinfo(message=message)
416
+ res = ""
417
+ _initialdir = ""
418
+ if dir_path != None:
419
+ _initialdir = dir_path if os.path.exists(dir_path) else ""
420
+ while res == "":
421
+ res = fd.askdirectory(title=message, initialdir=_initialdir)
422
+ if res == "" or type(res) == tuple:
423
+ ret = mb.askretrycancel(message="Do you want to cancel training?")
424
+ if not ret:
425
+ exit()
426
+ continue
427
+ if not os.path.exists(res):
428
+ res = ""
429
+ return res
430
+
431
+
432
+ def ask_elements_trunc(args: dict):
433
+ args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
434
+ args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
435
+ args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
436
+
437
+ ret = mb.askyesno(message="Do you want to save a json of your configuration?")
438
+ if ret:
439
+ args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
440
+ else:
441
+ args['save_json_folder'] = None
442
+
443
+ ret = mb.askyesno(message="Are you training on a SD2 based model?")
444
+ if ret:
445
+ args['v2'] = True
446
+
447
+ ret = mb.askyesno(message="Are you training on an realistic model?")
448
+ if ret:
449
+ args['clip_skip'] = 1
450
+
451
+ if args['v2']:
452
+ ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
453
+ if ret:
454
+ args['v_parameterization'] = True
455
+
456
+ ret = mb.askyesno(message="Do you want to use regularization images?")
457
+ if ret:
458
+ args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
459
+ else:
460
+ args['reg_img_folder'] = None
461
+
462
+ ret = mb.askyesno(message="Do you want to continue from an earlier version?")
463
+ if ret:
464
+ args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
465
+ args['lora_model_for_resume'])
466
+ else:
467
+ args['lora_model_for_resume'] = None
468
+
469
+ ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
470
+ "within your dataset but it can also ruin learning an asymmetrical element\n")
471
+ if ret:
472
+ args['flip_aug'] = True
473
+
474
+ ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
475
+ if ret:
476
+ ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
477
+ "Cancel keeps outputs the original")
478
+ if ret:
479
+ args['change_output_name'] = ret
480
+ else:
481
+ args['change_output_name'] = None
482
+
483
+ ret = sd.askstring(title="comment",
484
+ prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
485
+ "be to include how to use, such as activation keywords.\nCancel will leave empty")
486
+ if ret is not None:
487
+ args['training_comment'] = ret
488
+ else:
489
+ args['training_comment'] = None
490
+
491
+ ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
492
+ if ret:
493
+ button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
494
+ button.window.mainloop()
495
+ if button.current_value != "":
496
+ args[button.current_value] = True
497
+
498
+ ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
499
+ "of all tags that you have used in your training data?\n")
500
+ if ret:
501
+ args['tag_occurrence_txt_file'] = True
502
+ button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
503
+ button.window.mainloop()
504
+ if button.current_value == "alphabetically":
505
+ args['sort_tag_occurrence_alphabetically'] = True
506
+
507
+ ret = mb.askyesno(message="Do you want to use caption dropout?")
508
+ if ret:
509
+ ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
510
+ if ret:
511
+ ret = sd.askinteger(title="Caption_File_Dropout",
512
+ prompt="How often do you want caption files to drop out?\n"
513
+ "enter a number from 0 to 100 that is the percentage chance of dropout\n"
514
+ "Cancel sets to 0")
515
+ if ret and 0 <= ret <= 100:
516
+ args['caption_dropout_rate'] = ret / 100.0
517
+
518
+ ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
519
+ if ret:
520
+ ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
521
+ "epoch with no captions\nSo if you set 3, then every"
522
+ "three epochs will not have captions (3, 6, 9)\n"
523
+ "Cancel will set to None")
524
+ if ret:
525
+ args['caption_dropout_every_n_epochs'] = ret
526
+
527
+ ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
528
+ if ret:
529
+ ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
530
+ "Enter a number between 0 and 100, that is the percentage"
531
+ "chance of dropout.\nCancel sets to 0")
532
+ if ret and 0 <= ret <= 100:
533
+ args['caption_tag_dropout_rate'] = ret / 100.0
534
+
535
+ ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
536
+ "darker or lighter images using this than normal.")
537
+ if ret:
538
+ ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
539
+ "but it can go higher. Cancel defaults to 0.1")
540
+ if ret:
541
+ args['noise_offset'] = ret
542
+ else:
543
+ args['noise_offset'] = 0.1
544
+ return args
545
+
546
+
547
+ def ask_elements(args: dict):
548
+ # start with file dialog
549
+ args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
550
+ args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
551
+ args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
552
+
553
+ # optional file dialog
554
+ ret = mb.askyesno(message="Do you want to save a json of your configuration?")
555
+ if ret:
556
+ args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
557
+ else:
558
+ args['save_json_folder'] = None
559
+
560
+ ret = mb.askyesno(message="Are you training on a SD2 based model?")
561
+ if ret:
562
+ args['v2'] = True
563
+
564
+ ret = mb.askyesno(message="Are you training on an realistic model?")
565
+ if ret:
566
+ args['clip_skip'] = 1
567
+
568
+ if args['v2']:
569
+ ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
570
+ if ret:
571
+ args['v_parameterization'] = True
572
+
573
+ ret = mb.askyesno(message="Do you want to use regularization images?")
574
+ if ret:
575
+ args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
576
+ else:
577
+ args['reg_img_folder'] = None
578
+
579
+ ret = mb.askyesno(message="Do you want to continue from an earlier version?")
580
+ if ret:
581
+ args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
582
+ args['lora_model_for_resume'])
583
+ else:
584
+ args['lora_model_for_resume'] = None
585
+
586
+ ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
587
+ "within your dataset but it can also ruin learning an asymmetrical element\n")
588
+ if ret:
589
+ args['flip_aug'] = True
590
+
591
+ # text based required elements
592
+ ret = sd.askinteger(title="batch_size",
593
+ prompt="The number of images that get processed at one time, this is directly proportional to "
594
+ "your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 "
595
+ "batch size\nHow large is your batch size going to be?\nCancel will default to 1")
596
+ if ret is None:
597
+ args['batch_size'] = 1
598
+ else:
599
+ args['batch_size'] = ret
600
+
601
+ ret = sd.askinteger(title="num_epochs", prompt="How many epochs do you want?\nCancel will default to 1")
602
+ if ret is None:
603
+ args['num_epochs'] = 1
604
+ else:
605
+ args['num_epochs'] = ret
606
+
607
+ ret = sd.askinteger(title="network_dim", prompt="What is the dim size you want to use?\nCancel will default to 128")
608
+ if ret is None:
609
+ args['net_dim'] = 128
610
+ else:
611
+ args['net_dim'] = ret
612
+
613
+ ret = sd.askfloat(title="alpha", prompt="Alpha is the scalar of the training, generally a good starting point is "
614
+ "0.5x dim size\nWhat Alpha do you want?\nCancel will default to equal to "
615
+ "0.5 x network_dim")
616
+ if ret is None:
617
+ args['alpha'] = args['net_dim'] / 2
618
+ else:
619
+ args['alpha'] = ret
620
+
621
+ ret = sd.askinteger(title="resolution", prompt="How large of a resolution do you want to train at?\n"
622
+ "Cancel will default to 512")
623
+ if ret is None:
624
+ args['train_resolution'] = 512
625
+ else:
626
+ args['train_resolution'] = ret
627
+
628
+ ret = sd.askfloat(title="learning_rate", prompt="What learning rate do you want to use?\n"
629
+ "Cancel will default to 1e-4")
630
+ if ret is None:
631
+ args['learning_rate'] = 1e-4
632
+ else:
633
+ args['learning_rate'] = ret
634
+
635
+ ret = sd.askfloat(title="text_encoder_lr", prompt="Do you want to set the text_encoder_lr?\n"
636
+ "Cancel will default to None")
637
+ if ret is None:
638
+ args['text_encoder_lr'] = None
639
+ else:
640
+ args['text_encoder_lr'] = ret
641
+
642
+ ret = sd.askfloat(title="unet_lr", prompt="Do you want to set the unet_lr?\nCancel will default to None")
643
+ if ret is None:
644
+ args['unet_lr'] = None
645
+ else:
646
+ args['unet_lr'] = ret
647
+
648
+ button = ButtonBox("Which scheduler do you want?", ["cosine_with_restarts", "cosine", "polynomial",
649
+ "constant", "constant_with_warmup", "linear"])
650
+ button.window.mainloop()
651
+ args['scheduler'] = button.current_value if button.current_value != "" else "cosine_with_restarts"
652
+
653
+ if args['scheduler'] == "cosine_with_restarts":
654
+ ret = sd.askinteger(title="Cycle Count",
655
+ prompt="How many times do you want cosine to restart?\nThis is the entire amount of times "
656
+ "it will restart for the entire training\nCancel will default to 1")
657
+ if ret is None:
658
+ args['cosine_restarts'] = 1
659
+ else:
660
+ args['cosine_restarts'] = ret
661
+
662
+ if args['scheduler'] == "polynomial":
663
+ ret = sd.askfloat(title="Poly Strength",
664
+ prompt="What power do you want to set your polynomial to?\nhigher power means that the "
665
+ "model reduces the learning more more aggressively from initial training.\n1 = "
666
+ "linear\nCancel sets to 1")
667
+ if ret is None:
668
+ args['scheduler_power'] = 1
669
+ else:
670
+ args['scheduler_power'] = ret
671
+
672
+ ret = mb.askyesno(message="Do you want to save epochs as it trains?")
673
+ if ret:
674
+ ret = sd.askinteger(title="save_epoch",
675
+ prompt="How often do you want to save epochs?\nCancel will default to 1")
676
+ if ret is None:
677
+ args['save_every_n_epochs'] = 1
678
+ else:
679
+ args['save_every_n_epochs'] = ret
680
+
681
+ ret = mb.askyesno(message="Do you want to shuffle captions?")
682
+ if ret:
683
+ args['shuffle_captions'] = True
684
+ else:
685
+ args['shuffle_captions'] = False
686
+
687
+ ret = mb.askyesno(message="Do you want to keep some tokens at the front of your captions?")
688
+ if ret:
689
+ ret = sd.askinteger(title="keep_tokens", prompt="How many do you want to keep at the front?"
690
+ "\nCancel will default to 1")
691
+ if ret is None:
692
+ args['keep_tokens'] = 1
693
+ else:
694
+ args['keep_tokens'] = ret
695
+
696
+ ret = mb.askyesno(message="Do you want to have a warmup ratio?")
697
+ if ret:
698
+ ret = sd.askfloat(title="warmup_ratio", prompt="What is the ratio of steps to use as warmup "
699
+ "steps?\nCancel will default to None")
700
+ if ret is None:
701
+ args['warmup_lr_ratio'] = None
702
+ else:
703
+ args['warmup_lr_ratio'] = ret
704
+
705
+ ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
706
+ if ret:
707
+ ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
708
+ "Cancel keeps outputs the original")
709
+ if ret:
710
+ args['change_output_name'] = ret
711
+ else:
712
+ args['change_output_name'] = None
713
+
714
+ ret = sd.askstring(title="comment",
715
+ prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
716
+ "be to include how to use, such as activation keywords.\nCancel will leave empty")
717
+ if ret is not None:
718
+ args['training_comment'] = ret
719
+ else:
720
+ args['training_comment'] = None
721
+
722
+ ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
723
+ if ret:
724
+ if ret:
725
+ button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
726
+ button.window.mainloop()
727
+ if button.current_value != "":
728
+ args[button.current_value] = True
729
+
730
+ ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
731
+ "of all tags that you have used in your training data?\n")
732
+ if ret:
733
+ args['tag_occurrence_txt_file'] = True
734
+ button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
735
+ button.window.mainloop()
736
+ if button.current_value == "alphabetically":
737
+ args['sort_tag_occurrence_alphabetically'] = True
738
+
739
+ ret = mb.askyesno(message="Do you want to use caption dropout?")
740
+ if ret:
741
+ ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
742
+ if ret:
743
+ ret = sd.askinteger(title="Caption_File_Dropout",
744
+ prompt="How often do you want caption files to drop out?\n"
745
+ "enter a number from 0 to 100 that is the percentage chance of dropout\n"
746
+ "Cancel sets to 0")
747
+ if ret and 0 <= ret <= 100:
748
+ args['caption_dropout_rate'] = ret / 100.0
749
+
750
+ ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
751
+ if ret:
752
+ ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
753
+ "epoch with no captions\nSo if you set 3, then every"
754
+ "three epochs will not have captions (3, 6, 9)\n"
755
+ "Cancel will set to None")
756
+ if ret:
757
+ args['caption_dropout_every_n_epochs'] = ret
758
+
759
+ ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
760
+ if ret:
761
+ ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
762
+ "Enter a number between 0 and 100, that is the percentage"
763
+ "chance of dropout.\nCancel sets to 0")
764
+ if ret and 0 <= ret <= 100:
765
+ args['caption_tag_dropout_rate'] = ret / 100.0
766
+
767
+ ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
768
+ "darker or lighter images using this than normal.")
769
+ if ret:
770
+ ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
771
+ "but it can go higher. Cancel defaults to 0.1")
772
+ if ret:
773
+ args['noise_offset'] = ret
774
+ else:
775
+ args['noise_offset'] = 0.1
776
+ return args
777
+
778
+
779
+ def save_json(path, obj: dict) -> None:
780
+ fp = open(os.path.join(path, f"config-{time.time()}.json"), "w")
781
+ json.dump(obj, fp=fp, indent=4)
782
+ fp.close()
783
+
784
+
785
+ def load_json(path, obj: dict) -> dict:
786
+ with open(path) as f:
787
+ json_obj = json.loads(f.read())
788
+ print("loaded json, setting variables...")
789
+ ui_name_scheme = {"pretrained_model_name_or_path": "base_model", "logging_dir": "log_dir",
790
+ "train_data_dir": "img_folder", "reg_data_dir": "reg_img_folder",
791
+ "output_dir": "output_folder", "max_resolution": "train_resolution",
792
+ "lr_scheduler": "scheduler", "lr_warmup": "warmup_lr_ratio",
793
+ "train_batch_size": "batch_size", "epoch": "num_epochs",
794
+ "save_at_n_epochs": "save_every_n_epochs", "num_cpu_threads_per_process": "num_workers",
795
+ "enable_bucket": "buckets", "save_model_as": "save_as", "shuffle_caption": "shuffle_captions",
796
+ "resume": "load_previous_save_state", "network_dim": "net_dim",
797
+ "gradient_accumulation_steps": "gradient_acc_steps", "output_name": "change_output_name",
798
+ "network_alpha": "alpha", "lr_scheduler_num_cycles": "cosine_restarts",
799
+ "lr_scheduler_power": "scheduler_power"}
800
+
801
+ for key in list(json_obj):
802
+ if key in ui_name_scheme:
803
+ json_obj[ui_name_scheme[key]] = json_obj[key]
804
+ if ui_name_scheme[key] in {"batch_size", "num_epochs"}:
805
+ try:
806
+ json_obj[ui_name_scheme[key]] = int(json_obj[ui_name_scheme[key]])
807
+ except ValueError:
808
+ print(f"attempting to load {key} from json failed as input isn't an integer")
809
+ quit(1)
810
+
811
+ for key in list(json_obj):
812
+ if obj["json_load_skip_list"] and key in obj["json_load_skip_list"]:
813
+ continue
814
+ if key in obj:
815
+ if key in {"keep_tokens", "warmup_lr_ratio"}:
816
+ json_obj[key] = int(json_obj[key]) if json_obj[key] is not None else None
817
+ if key in {"learning_rate", "unet_lr", "text_encoder_lr"}:
818
+ json_obj[key] = float(json_obj[key]) if json_obj[key] is not None else None
819
+ if obj[key] != json_obj[key]:
820
+ print_change(key, obj[key], json_obj[key])
821
+ obj[key] = json_obj[key]
822
+ print("completed changing variables.")
823
+ return obj
824
+
825
+
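The core of load_json is the ui_name_scheme remap from other UIs' key names to this script's names, plus int/float coercion for a few fields. A standalone sketch of that remapping with a hypothetical config dict:

import json

raw = {"train_batch_size": "4", "lr_scheduler": "cosine", "output_name": "my_lora"}
ui_name_scheme = {"train_batch_size": "batch_size", "lr_scheduler": "scheduler",
                  "output_name": "change_output_name"}

remapped = {ui_name_scheme.get(k, k): v for k, v in raw.items()}
remapped["batch_size"] = int(remapped["batch_size"])  # load_json coerces batch_size/num_epochs to int
print(json.dumps(remapped, indent=4))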
826
+ def print_change(value, old, new):
827
+ print(f"{value} changed from {old} to {new}")
828
+
829
+
830
+ class ButtonBox:
831
+ def __init__(self, label: str, button_name_list: list[str]) -> None:
832
+ self.window = tk.Tk()
833
+ self.button_list = []
834
+ self.current_value = ""
835
+
836
+ self.window.attributes("-topmost", True)
837
+ self.window.resizable(False, False)
838
+ self.window.eval('tk::PlaceWindow . center')
839
+
840
+ def del_window():
841
+ self.window.quit()
842
+ self.window.destroy()
843
+
844
+ self.window.protocol("WM_DELETE_WINDOW", del_window)
845
+ tk.Label(text=label, master=self.window).pack()
846
+ for button in button_name_list:
847
+ self.button_list.append(ttk.Button(text=button, master=self.window,
848
+ command=partial(self.set_current_value, button)))
849
+ self.button_list[-1].pack()
850
+
851
+ def set_current_value(self, value):
852
+ self.current_value = value
853
+ self.window.quit()
854
+ self.window.destroy()
855
+
856
+
857
+ root = tk.Tk()
858
+ root.attributes('-topmost', True)
859
+ root.withdraw()
860
+
861
+ if __name__ == "__main__":
862
+ main()
lycoris/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from lycoris import (
2
+ kohya,
3
+ kohya_model_utils,
4
+ kohya_utils,
5
+ locon,
6
+ loha,
7
+ utils,
8
+ )
lycoris/kohya.py ADDED
@@ -0,0 +1,276 @@
1
+ # network module for kohya
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
+
7
+ import math
8
+ import os
9
+ from typing import List
10
+ import torch
11
+
12
+ from .kohya_utils import *
13
+ from .locon import LoConModule
14
+ from .loha import LohaModule
15
+
16
+
17
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
18
+ if network_dim is None:
19
+ network_dim = 4 # default
20
+ conv_dim = int(kwargs.get('conv_dim', network_dim))
21
+ conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
22
+ dropout = float(kwargs.get('dropout', 0.))
23
+ algo = kwargs.get('algo', 'lora')
24
+ network_module = {
25
+ 'lora': LoConModule,
26
+ 'loha': LohaModule,
27
+ }[algo]
28
+
29
+ print(f'Using rank adaptation algo: {algo}')
30
+ network = LoRANetwork(
31
+ text_encoder, unet,
32
+ multiplier=multiplier,
33
+ lora_dim=network_dim, conv_lora_dim=conv_dim,
34
+ alpha=network_alpha, conv_alpha=conv_alpha,
35
+ dropout=dropout,
36
+ network_module=network_module
37
+ )
38
+
39
+ return network
40
+
41
+
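The extra kwargs read by create_network (conv_dim, conv_alpha, dropout, algo) normally arrive from sd-scripts' --network_args, which splits each "key=value" string into a dict before calling the module. A rough sketch of that hand-off (the command line shown is hypothetical):

# Hypothetical invocation:
#   --network_module=lycoris.kohya --network_dim=16 --network_alpha=8 \
#   --network_args "algo=loha" "conv_dim=8" "conv_alpha=4" "dropout=0.1"
network_args = ["algo=loha", "conv_dim=8", "conv_alpha=4", "dropout=0.1"]
kwargs = dict(arg.split("=", 1) for arg in network_args)
print(kwargs)  # create_network(...) then pulls conv_dim, conv_alpha, dropout and algo from these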
42
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
43
+ if os.path.splitext(file)[1] == '.safetensors':
44
+ from safetensors.torch import load_file, safe_open
45
+ weights_sd = load_file(file)
46
+ else:
47
+ weights_sd = torch.load(file, map_location='cpu')
48
+
49
+ # get dim (rank)
50
+ network_alpha = None
51
+ network_dim = None
52
+ for key, value in weights_sd.items():
53
+ if network_alpha is None and 'alpha' in key:
54
+ network_alpha = value
55
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
56
+ network_dim = value.size()[0]
57
+
58
+ if network_alpha is None:
59
+ network_alpha = network_dim
60
+
61
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
62
+ network.weights_sd = weights_sd
63
+ return network
64
+
65
+
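What the rank/alpha inference above sees in practice: each LoRA module contributes a 2-D lora_down weight (rank x in_features) and an alpha scalar. A small sketch against a hypothetical in-memory state dict rather than a file:

import torch

weights_sd = {
    "lora_unet_x.lora_down.weight": torch.zeros(8, 320),  # rank 8 down-projection
    "lora_unet_x.alpha": torch.tensor(4.0),
}

network_alpha = network_dim = None
for key, value in weights_sd.items():
    if network_alpha is None and "alpha" in key:
        network_alpha = value
    if network_dim is None and "lora_down" in key and len(value.size()) == 2:
        network_dim = value.size()[0]

print(network_dim, float(network_alpha))  # 8 4.0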
66
+ class LoRANetwork(torch.nn.Module):
67
+ '''
68
+ LoRA + LoCon
69
+ '''
70
+ # Ignore proj_in and proj_out; they only have a few channels.
71
+ UNET_TARGET_REPLACE_MODULE = [
72
+ "Transformer2DModel",
73
+ "Attention",
74
+ "ResnetBlock2D",
75
+ "Downsample2D",
76
+ "Upsample2D"
77
+ ]
78
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
79
+ LORA_PREFIX_UNET = 'lora_unet'
80
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
81
+
82
+ def __init__(
83
+ self,
84
+ text_encoder, unet,
85
+ multiplier=1.0,
86
+ lora_dim=4, conv_lora_dim=4,
87
+ alpha=1, conv_alpha=1,
88
+ dropout = 0, network_module = LoConModule,
89
+ ) -> None:
90
+ super().__init__()
91
+ self.multiplier = multiplier
92
+ self.lora_dim = lora_dim
93
+ self.conv_lora_dim = int(conv_lora_dim)
94
+ if self.conv_lora_dim != self.lora_dim:
95
+ print('Apply different lora dim for conv layer')
96
+ print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
97
+
98
+ self.alpha = alpha
99
+ self.conv_alpha = float(conv_alpha)
100
+ if self.alpha != self.conv_alpha:
101
+ print('Apply different alpha value for conv layer')
102
+ print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
103
+
104
+ if 1 >= dropout >= 0:
105
+ print(f'Use Dropout value: {dropout}')
106
+ self.dropout = dropout
107
+
108
+ # create module instances
109
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[network_module]:
110
+ print('Create LoCon Module')
111
+ loras = []
112
+ for name, module in root_module.named_modules():
113
+ if module.__class__.__name__ in target_replace_modules:
114
+ for child_name, child_module in module.named_modules():
115
+ lora_name = prefix + '.' + name + '.' + child_name
116
+ lora_name = lora_name.replace('.', '_')
117
+ if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
118
+ lora = network_module(
119
+ lora_name, child_module, self.multiplier,
120
+ self.lora_dim, self.alpha, self.dropout
121
+ )
122
+ elif child_module.__class__.__name__ == 'Conv2d':
123
+ k_size, *_ = child_module.kernel_size
124
+ if k_size==1 and lora_dim>0:
125
+ lora = network_module(
126
+ lora_name, child_module, self.multiplier,
127
+ self.lora_dim, self.alpha, self.dropout
128
+ )
129
+ elif conv_lora_dim>0:
130
+ lora = network_module(
131
+ lora_name, child_module, self.multiplier,
132
+ self.conv_lora_dim, self.conv_alpha, self.dropout
133
+ )
134
+ else:
135
+ continue
136
+ else:
137
+ continue
138
+ loras.append(lora)
139
+ return loras
140
+
141
+ self.text_encoder_loras = create_modules(
142
+ LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
143
+ text_encoder,
144
+ LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
145
+ )
146
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
147
+
148
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
149
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
150
+
151
+ self.weights_sd = None
152
+
153
+ # assertion
154
+ names = set()
155
+ for lora in self.text_encoder_loras + self.unet_loras:
156
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
157
+ names.add(lora.lora_name)
158
+
159
+ def set_multiplier(self, multiplier):
160
+ self.multiplier = multiplier
161
+ for lora in self.text_encoder_loras + self.unet_loras:
162
+ lora.multiplier = self.multiplier
163
+
164
+ def load_weights(self, file):
165
+ if os.path.splitext(file)[1] == '.safetensors':
166
+ from safetensors.torch import load_file, safe_open
167
+ self.weights_sd = load_file(file)
168
+ else:
169
+ self.weights_sd = torch.load(file, map_location='cpu')
170
+
171
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
172
+ if self.weights_sd:
173
+ weights_has_text_encoder = weights_has_unet = False
174
+ for key in self.weights_sd.keys():
175
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
176
+ weights_has_text_encoder = True
177
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
178
+ weights_has_unet = True
179
+
180
+ if apply_text_encoder is None:
181
+ apply_text_encoder = weights_has_text_encoder
182
+ else:
183
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
184
+
185
+ if apply_unet is None:
186
+ apply_unet = weights_has_unet
187
+ else:
188
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
189
+ else:
190
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
191
+
192
+ if apply_text_encoder:
193
+ print("enable LoRA for text encoder")
194
+ else:
195
+ self.text_encoder_loras = []
196
+
197
+ if apply_unet:
198
+ print("enable LoRA for U-Net")
199
+ else:
200
+ self.unet_loras = []
201
+
202
+ for lora in self.text_encoder_loras + self.unet_loras:
203
+ lora.apply_to()
204
+ self.add_module(lora.lora_name, lora)
205
+
206
+ if self.weights_sd:
207
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
208
+ info = self.load_state_dict(self.weights_sd, False)
209
+ print(f"weights are loaded: {info}")
210
+
211
+ def enable_gradient_checkpointing(self):
212
+ # not supported
213
+ def make_ckpt(module):
214
+ if isinstance(module, torch.nn.Module):
215
+ module.grad_ckpt = True
216
+ self.apply(make_ckpt)
217
+ pass
218
+
219
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
220
+ def enumerate_params(loras):
221
+ params = []
222
+ for lora in loras:
223
+ params.extend(lora.parameters())
224
+ return params
225
+
226
+ self.requires_grad_(True)
227
+ all_params = []
228
+
229
+ if self.text_encoder_loras:
230
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
231
+ if text_encoder_lr is not None:
232
+ param_data['lr'] = text_encoder_lr
233
+ all_params.append(param_data)
234
+
235
+ if self.unet_loras:
236
+ param_data = {'params': enumerate_params(self.unet_loras)}
237
+ if unet_lr is not None:
238
+ param_data['lr'] = unet_lr
239
+ all_params.append(param_data)
240
+
241
+ return all_params
242
+
243
+ def prepare_grad_etc(self, text_encoder, unet):
244
+ self.requires_grad_(True)
245
+
246
+ def on_epoch_start(self, text_encoder, unet):
247
+ self.train()
248
+
249
+ def get_trainable_params(self):
250
+ return self.parameters()
251
+
252
+ def save_weights(self, file, dtype, metadata):
253
+ if metadata is not None and len(metadata) == 0:
254
+ metadata = None
255
+
256
+ state_dict = self.state_dict()
257
+
258
+ if dtype is not None:
259
+ for key in list(state_dict.keys()):
260
+ v = state_dict[key]
261
+ v = v.detach().clone().to("cpu").to(dtype)
262
+ state_dict[key] = v
263
+
264
+ if os.path.splitext(file)[1] == '.safetensors':
265
+ from safetensors.torch import save_file
266
+
267
+ # Precalculate model hashes to save time on indexing
268
+ if metadata is None:
269
+ metadata = {}
270
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
271
+ metadata["sshs_model_hash"] = model_hash
272
+ metadata["sshs_legacy_hash"] = legacy_hash
273
+
274
+ save_file(state_dict, file, metadata)
275
+ else:
276
+ torch.save(state_dict, file)
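Before writing, save_weights detaches every tensor, moves it to the CPU and casts it to the requested dtype. The same pass in isolation, on a hypothetical one-entry state dict:

import torch

state_dict = {"lora_unet_x.lora_down.weight": torch.randn(8, 320)}
dtype = torch.float16

for key in list(state_dict.keys()):
    state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

print(state_dict["lora_unet_x.lora_down.weight"].dtype)  # torch.float16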
lycoris/kohya_model_utils.py ADDED
@@ -0,0 +1,1184 @@
1
+ '''
2
+ https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
3
+ '''
4
+ # v1: split from train_db_fixed.py.
5
+ # v2: support safetensors
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
11
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
12
+ from safetensors.torch import load_file, save_file
13
+
14
+ # Model parameters for the Diffusers version of Stable Diffusion
15
+ NUM_TRAIN_TIMESTEPS = 1000
16
+ BETA_START = 0.00085
17
+ BETA_END = 0.0120
18
+
19
+ UNET_PARAMS_MODEL_CHANNELS = 320
20
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
21
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
22
+ UNET_PARAMS_IMAGE_SIZE = 32 # unused
23
+ UNET_PARAMS_IN_CHANNELS = 4
24
+ UNET_PARAMS_OUT_CHANNELS = 4
25
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
26
+ UNET_PARAMS_CONTEXT_DIM = 768
27
+ UNET_PARAMS_NUM_HEADS = 8
28
+
29
+ VAE_PARAMS_Z_CHANNELS = 4
30
+ VAE_PARAMS_RESOLUTION = 256
31
+ VAE_PARAMS_IN_CHANNELS = 3
32
+ VAE_PARAMS_OUT_CH = 3
33
+ VAE_PARAMS_CH = 128
34
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
35
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
36
+
37
+ # V2
38
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
39
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
40
+
41
+ # Reference models used to load the Diffusers configuration
42
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
43
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
44
+
45
+
46
+ # region StableDiffusion->Diffusersの変換コード
47
+ # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
48
+
49
+
50
+ def shave_segments(path, n_shave_prefix_segments=1):
51
+ """
52
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
53
+ """
54
+ if n_shave_prefix_segments >= 0:
55
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
56
+ else:
57
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
58
+
59
+
60
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
+ """
62
+ Updates paths inside resnets to the new naming scheme (local renaming)
63
+ """
64
+ mapping = []
65
+ for old_item in old_list:
66
+ new_item = old_item.replace("in_layers.0", "norm1")
67
+ new_item = new_item.replace("in_layers.2", "conv1")
68
+
69
+ new_item = new_item.replace("out_layers.0", "norm2")
70
+ new_item = new_item.replace("out_layers.3", "conv2")
71
+
72
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
74
+
75
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
+
77
+ mapping.append({"old": old_item, "new": new_item})
78
+
79
+ return mapping
80
+
81
+
82
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
+ """
84
+ Updates paths inside resnets to the new naming scheme (local renaming)
85
+ """
86
+ mapping = []
87
+ for old_item in old_list:
88
+ new_item = old_item
89
+
90
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
+
93
+ mapping.append({"old": old_item, "new": new_item})
94
+
95
+ return mapping
96
+
97
+
98
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
+ """
100
+ Updates paths inside attentions to the new naming scheme (local renaming)
101
+ """
102
+ mapping = []
103
+ for old_item in old_list:
104
+ new_item = old_item
105
+
106
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
+
109
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
+
112
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
+ """
121
+ Updates paths inside attentions to the new naming scheme (local renaming)
122
+ """
123
+ mapping = []
124
+ for old_item in old_list:
125
+ new_item = old_item
126
+
127
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
128
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
129
+
130
+ new_item = new_item.replace("q.weight", "query.weight")
131
+ new_item = new_item.replace("q.bias", "query.bias")
132
+
133
+ new_item = new_item.replace("k.weight", "key.weight")
134
+ new_item = new_item.replace("k.bias", "key.bias")
135
+
136
+ new_item = new_item.replace("v.weight", "value.weight")
137
+ new_item = new_item.replace("v.bias", "value.bias")
138
+
139
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
+
142
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
+
144
+ mapping.append({"old": old_item, "new": new_item})
145
+
146
+ return mapping
147
+
148
+
149
+ def assign_to_checkpoint(
150
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
+ ):
152
+ """
153
+ This does the final conversion step: take locally converted weights and apply a global renaming
154
+ to them. It splits attention layers, and takes into account additional replacements
155
+ that may arise.
156
+
157
+ Assigns the weights to the new checkpoint.
158
+ """
159
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
+
161
+ # Splits the attention layers into three variables.
162
+ if attention_paths_to_split is not None:
163
+ for path, path_map in attention_paths_to_split.items():
164
+ old_tensor = old_checkpoint[path]
165
+ channels = old_tensor.shape[0] // 3
166
+
167
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
+
169
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
+
171
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
+
174
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
175
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
176
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
177
+
178
+ for path in paths:
179
+ new_path = path["new"]
180
+
181
+ # These have already been assigned
182
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
+ continue
184
+
185
+ # Global renaming happens here
186
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
+
190
+ if additional_replacements is not None:
191
+ for replacement in additional_replacements:
192
+ new_path = new_path.replace(replacement["old"], replacement["new"])
193
+
194
+ # proj_attn.weight has to be converted from conv 1D to linear
195
+ if "proj_attn.weight" in new_path:
196
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
+ else:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]]
199
+
200
+
201
+ def conv_attn_to_linear(checkpoint):
202
+ keys = list(checkpoint.keys())
203
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
204
+ for key in keys:
205
+ if ".".join(key.split(".")[-2:]) in attn_keys:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
+ elif "proj_attn.weight" in key:
209
+ if checkpoint[key].ndim > 2:
210
+ checkpoint[key] = checkpoint[key][:, :, 0]
211
+
212
+
213
+ def linear_transformer_to_conv(checkpoint):
214
+ keys = list(checkpoint.keys())
215
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
216
+ for key in keys:
217
+ if ".".join(key.split(".")[-2:]) in tf_keys:
218
+ if checkpoint[key].ndim == 2:
219
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
220
+
221
+
222
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
223
+ """
224
+ Takes a state dict and a config, and returns a converted checkpoint.
225
+ """
226
+
227
+ # extract state_dict for UNet
228
+ unet_state_dict = {}
229
+ unet_key = "model.diffusion_model."
230
+ keys = list(checkpoint.keys())
231
+ for key in keys:
232
+ if key.startswith(unet_key):
233
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
234
+
235
+ new_checkpoint = {}
236
+
237
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
238
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
239
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
240
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
241
+
242
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
243
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
244
+
245
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
246
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
247
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
248
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
249
+
250
+ # Retrieves the keys for the input blocks only
251
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
252
+ input_blocks = {
253
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
254
+ for layer_id in range(num_input_blocks)
255
+ }
256
+
257
+ # Retrieves the keys for the middle blocks only
258
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
259
+ middle_blocks = {
260
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
261
+ for layer_id in range(num_middle_blocks)
262
+ }
263
+
264
+ # Retrieves the keys for the output blocks only
265
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
266
+ output_blocks = {
267
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
268
+ for layer_id in range(num_output_blocks)
269
+ }
270
+
271
+ for i in range(1, num_input_blocks):
272
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
273
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
274
+
275
+ resnets = [
276
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
277
+ ]
278
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
279
+
280
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.weight"
283
+ )
284
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
285
+ f"input_blocks.{i}.0.op.bias"
286
+ )
287
+
288
+ paths = renew_resnet_paths(resnets)
289
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
290
+ assign_to_checkpoint(
291
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
292
+ )
293
+
294
+ if len(attentions):
295
+ paths = renew_attention_paths(attentions)
296
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
297
+ assign_to_checkpoint(
298
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
299
+ )
300
+
301
+ resnet_0 = middle_blocks[0]
302
+ attentions = middle_blocks[1]
303
+ resnet_1 = middle_blocks[2]
304
+
305
+ resnet_0_paths = renew_resnet_paths(resnet_0)
306
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ resnet_1_paths = renew_resnet_paths(resnet_1)
309
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
310
+
311
+ attentions_paths = renew_attention_paths(attentions)
312
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
313
+ assign_to_checkpoint(
314
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
315
+ )
316
+
317
+ for i in range(num_output_blocks):
318
+ block_id = i // (config["layers_per_block"] + 1)
319
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
320
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
321
+ output_block_list = {}
322
+
323
+ for layer in output_block_layers:
324
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
325
+ if layer_id in output_block_list:
326
+ output_block_list[layer_id].append(layer_name)
327
+ else:
328
+ output_block_list[layer_id] = [layer_name]
329
+
330
+ if len(output_block_list) > 1:
331
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
332
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
333
+
334
+ resnet_0_paths = renew_resnet_paths(resnets)
335
+ paths = renew_resnet_paths(resnets)
336
+
337
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
338
+ assign_to_checkpoint(
339
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
340
+ )
341
+
342
+ # original:
343
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
344
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
345
+
346
+ # make this independent of the order of bias and weight (there is probably a better way)
347
+ for l in output_block_list.values():
348
+ l.sort()
349
+
350
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
351
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.bias"
354
+ ]
355
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
356
+ f"output_blocks.{i}.{index}.conv.weight"
357
+ ]
358
+
359
+ # Clear attentions as they have been attributed above.
360
+ if len(attentions) == 2:
361
+ attentions = []
362
+
363
+ if len(attentions):
364
+ paths = renew_attention_paths(attentions)
365
+ meta_path = {
366
+ "old": f"output_blocks.{i}.1",
367
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
368
+ }
369
+ assign_to_checkpoint(
370
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
371
+ )
372
+ else:
373
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
374
+ for path in resnet_0_paths:
375
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
376
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
377
+
378
+ new_checkpoint[new_path] = unet_state_dict[old_path]
379
+
380
+ # in SD v2, the 1x1 conv2d layers have been changed to linear, so convert linear -> conv
381
+ if v2:
382
+ linear_transformer_to_conv(new_checkpoint)
383
+
384
+ return new_checkpoint
385
+
386
+
387
+ def convert_ldm_vae_checkpoint(checkpoint, config):
388
+ # extract state dict for VAE
389
+ vae_state_dict = {}
390
+ vae_key = "first_stage_model."
391
+ keys = list(checkpoint.keys())
392
+ for key in keys:
393
+ if key.startswith(vae_key):
394
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
395
+ # if len(vae_state_dict) == 0:
396
+ # # the checkpoint passed in is not one loaded from a .ckpt but the VAE's state_dict
397
+ # vae_state_dict = checkpoint
398
+
399
+ new_checkpoint = {}
400
+
401
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
402
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
403
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
404
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
405
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
406
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
407
+
408
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
409
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
410
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
411
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
412
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
413
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
414
+
415
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
416
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
417
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
418
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
419
+
420
+ # Retrieves the keys for the encoder down blocks only
421
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
422
+ down_blocks = {
423
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
424
+ }
425
+
426
+ # Retrieves the keys for the decoder up blocks only
427
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
428
+ up_blocks = {
429
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
430
+ }
431
+
432
+ for i in range(num_down_blocks):
433
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
434
+
435
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.weight"
438
+ )
439
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
440
+ f"encoder.down.{i}.downsample.conv.bias"
441
+ )
442
+
443
+ paths = renew_vae_resnet_paths(resnets)
444
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
445
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
446
+
447
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
448
+ num_mid_res_blocks = 2
449
+ for i in range(1, num_mid_res_blocks + 1):
450
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
451
+
452
+ paths = renew_vae_resnet_paths(resnets)
453
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
454
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
+
456
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
457
+ paths = renew_vae_attention_paths(mid_attentions)
458
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
459
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
460
+ conv_attn_to_linear(new_checkpoint)
461
+
462
+ for i in range(num_up_blocks):
463
+ block_id = num_up_blocks - 1 - i
464
+ resnets = [
465
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
466
+ ]
467
+
468
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.weight"
471
+ ]
472
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
473
+ f"decoder.up.{block_id}.upsample.conv.bias"
474
+ ]
475
+
476
+ paths = renew_vae_resnet_paths(resnets)
477
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
478
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
479
+
480
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
481
+ num_mid_res_blocks = 2
482
+ for i in range(1, num_mid_res_blocks + 1):
483
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
484
+
485
+ paths = renew_vae_resnet_paths(resnets)
486
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
487
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
488
+
489
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
490
+ paths = renew_vae_attention_paths(mid_attentions)
491
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
492
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
493
+ conv_attn_to_linear(new_checkpoint)
494
+ return new_checkpoint
495
+
496
+
497
+ def create_unet_diffusers_config(v2):
498
+ """
499
+ Creates a config for the diffusers based on the config of the LDM model.
500
+ """
501
+ # unet_params = original_config.model.params.unet_config.params
502
+
503
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
504
+
505
+ down_block_types = []
506
+ resolution = 1
507
+ for i in range(len(block_out_channels)):
508
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
509
+ down_block_types.append(block_type)
510
+ if i != len(block_out_channels) - 1:
511
+ resolution *= 2
512
+
513
+ up_block_types = []
514
+ for i in range(len(block_out_channels)):
515
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
516
+ up_block_types.append(block_type)
517
+ resolution //= 2
518
+
519
+ config = dict(
520
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
521
+ in_channels=UNET_PARAMS_IN_CHANNELS,
522
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
523
+ down_block_types=tuple(down_block_types),
524
+ up_block_types=tuple(up_block_types),
525
+ block_out_channels=tuple(block_out_channels),
526
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
527
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
528
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
529
+ )
530
+
531
+ return config
532
+
533
+
534
+ def create_vae_diffusers_config():
535
+ """
536
+ Creates a config for the diffusers based on the config of the LDM model.
537
+ """
538
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
539
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
540
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
541
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
542
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
543
+
544
+ config = dict(
545
+ sample_size=VAE_PARAMS_RESOLUTION,
546
+ in_channels=VAE_PARAMS_IN_CHANNELS,
547
+ out_channels=VAE_PARAMS_OUT_CH,
548
+ down_block_types=tuple(down_block_types),
549
+ up_block_types=tuple(up_block_types),
550
+ block_out_channels=tuple(block_out_channels),
551
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
552
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
553
+ )
554
+ return config
555
+
556
+
557
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
558
+ keys = list(checkpoint.keys())
559
+ text_model_dict = {}
560
+ for key in keys:
561
+ if key.startswith("cond_stage_model.transformer"):
562
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
563
+ return text_model_dict
564
+
565
+
566
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
567
+ # the key layout is almost completely different!
568
+ def convert_key(key):
569
+ if not key.startswith("cond_stage_model"):
570
+ return None
571
+
572
+ # common conversion
573
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
574
+ key = key.replace("cond_stage_model.model.", "text_model.")
575
+
576
+ if "resblocks" in key:
577
+ # resblocks conversion
578
+ key = key.replace(".resblocks.", ".layers.")
579
+ if ".ln_" in key:
580
+ key = key.replace(".ln_", ".layer_norm")
581
+ elif ".mlp." in key:
582
+ key = key.replace(".c_fc.", ".fc1.")
583
+ key = key.replace(".c_proj.", ".fc2.")
584
+ elif '.attn.out_proj' in key:
585
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
586
+ elif '.attn.in_proj' in key:
587
+ key = None # special case, handled later
588
+ else:
589
+ raise ValueError(f"unexpected key in SD: {key}")
590
+ elif '.positional_embedding' in key:
591
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
592
+ elif '.text_projection' in key:
593
+ key = None # apparently unused
594
+ elif '.logit_scale' in key:
595
+ key = None # apparently unused
596
+ elif '.token_embedding' in key:
597
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
598
+ elif '.ln_final' in key:
599
+ key = key.replace(".ln_final", ".final_layer_norm")
600
+ return key
601
+
602
+ keys = list(checkpoint.keys())
603
+ new_sd = {}
604
+ for key in keys:
605
+ # remove resblocks 23
606
+ if '.resblocks.23.' in key:
607
+ continue
608
+ new_key = convert_key(key)
609
+ if new_key is None:
610
+ continue
611
+ new_sd[new_key] = checkpoint[key]
612
+
613
+ # convert attention weights
614
+ for key in keys:
615
+ if '.resblocks.23.' in key:
616
+ continue
617
+ if '.resblocks' in key and '.attn.in_proj_' in key:
618
+ # split into the three projections (q, k, v)
619
+ values = torch.chunk(checkpoint[key], 3)
620
+
621
+ key_suffix = ".weight" if "weight" in key else ".bias"
622
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
623
+ key_pfx = key_pfx.replace("_weight", "")
624
+ key_pfx = key_pfx.replace("_bias", "")
625
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
626
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
627
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
628
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
629
+
630
+ # rename or add position_ids
631
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
632
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
633
+ # waifu diffusion v1.4
634
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
635
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
636
+ else:
637
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
638
+
639
+ new_sd["text_model.embeddings.position_ids"] = position_ids
640
+ return new_sd
641
+
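# Shape illustration (hypothetical tensor) for the in_proj split above: the open_clip
# text encoder stores q/k/v as one fused in_proj tensor, which torch.chunk splits into
# the three Hugging Face projections.
#
#   in_proj_weight = torch.randn(3 * 1024, 1024)    # (3 * hidden, hidden) for the SD2 text encoder
#   q_w, k_w, v_w = torch.chunk(in_proj_weight, 3)  # each (1024, 1024)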
642
+ # endregion
643
+
644
+
645
+ # region conversion code: Diffusers -> Stable Diffusion
646
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (Apache License 2.0)
647
+
648
+ def conv_transformer_to_linear(checkpoint):
649
+ keys = list(checkpoint.keys())
650
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
651
+ for key in keys:
652
+ if ".".join(key.split(".")[-2:]) in tf_keys:
653
+ if checkpoint[key].ndim > 2:
654
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
655
+
656
+
657
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
658
+ unet_conversion_map = [
659
+ # (stable-diffusion, HF Diffusers)
660
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
661
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
662
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
663
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
664
+ ("input_blocks.0.0.weight", "conv_in.weight"),
665
+ ("input_blocks.0.0.bias", "conv_in.bias"),
666
+ ("out.0.weight", "conv_norm_out.weight"),
667
+ ("out.0.bias", "conv_norm_out.bias"),
668
+ ("out.2.weight", "conv_out.weight"),
669
+ ("out.2.bias", "conv_out.bias"),
670
+ ]
671
+
672
+ unet_conversion_map_resnet = [
673
+ # (stable-diffusion, HF Diffusers)
674
+ ("in_layers.0", "norm1"),
675
+ ("in_layers.2", "conv1"),
676
+ ("out_layers.0", "norm2"),
677
+ ("out_layers.3", "conv2"),
678
+ ("emb_layers.1", "time_emb_proj"),
679
+ ("skip_connection", "conv_shortcut"),
680
+ ]
681
+
682
+ unet_conversion_map_layer = []
683
+ for i in range(4):
684
+ # loop over downblocks/upblocks
685
+
686
+ for j in range(2):
687
+ # loop over resnets/attentions for downblocks
688
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
689
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
690
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
691
+
692
+ if i < 3:
693
+ # no attention layers in down_blocks.3
694
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
695
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
696
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
697
+
698
+ for j in range(3):
699
+ # loop over resnets/attentions for upblocks
700
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
701
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
702
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
703
+
704
+ if i > 0:
705
+ # no attention layers in up_blocks.0
706
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
707
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
708
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
709
+
710
+ if i < 3:
711
+ # no downsample in down_blocks.3
712
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
713
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
714
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
715
+
716
+ # no upsample in up_blocks.3
717
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
718
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
719
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
720
+
721
+ hf_mid_atn_prefix = "mid_block.attentions.0."
722
+ sd_mid_atn_prefix = "middle_block.1."
723
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
724
+
725
+ for j in range(2):
726
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
727
+ sd_mid_res_prefix = f"middle_block.{2*j}."
728
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
729
+
730
+ # buyer beware: this is a *brittle* function,
731
+ # and correct output requires that all of these pieces interact in
732
+ # the exact order in which I have arranged them.
733
+ mapping = {k: k for k in unet_state_dict.keys()}
734
+ for sd_name, hf_name in unet_conversion_map:
735
+ mapping[hf_name] = sd_name
736
+ for k, v in mapping.items():
737
+ if "resnets" in k:
738
+ for sd_part, hf_part in unet_conversion_map_resnet:
739
+ v = v.replace(hf_part, sd_part)
740
+ mapping[k] = v
741
+ for k, v in mapping.items():
742
+ for sd_part, hf_part in unet_conversion_map_layer:
743
+ v = v.replace(hf_part, sd_part)
744
+ mapping[k] = v
745
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
746
+
747
+ if v2:
748
+ conv_transformer_to_linear(new_state_dict)
749
+
750
+ return new_state_dict
751
+
752
+
753
+ # ================#
754
+ # VAE Conversion #
755
+ # ================#
756
+
757
+ def reshape_weight_for_sd(w):
758
+ # convert HF linear weights to SD conv2d weights
759
+ return w.reshape(*w.shape, 1, 1)
760
+
761
+
762
+ def convert_vae_state_dict(vae_state_dict):
763
+ vae_conversion_map = [
764
+ # (stable-diffusion, HF Diffusers)
765
+ ("nin_shortcut", "conv_shortcut"),
766
+ ("norm_out", "conv_norm_out"),
767
+ ("mid.attn_1.", "mid_block.attentions.0."),
768
+ ]
769
+
770
+ for i in range(4):
771
+ # down_blocks have two resnets
772
+ for j in range(2):
773
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
774
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
775
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
776
+
777
+ if i < 3:
778
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
779
+ sd_downsample_prefix = f"down.{i}.downsample."
780
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
781
+
782
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
783
+ sd_upsample_prefix = f"up.{3-i}.upsample."
784
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
785
+
786
+ # up_blocks have three resnets
787
+ # also, up blocks in hf are numbered in reverse from sd
788
+ for j in range(3):
789
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
790
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
791
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
792
+
793
+ # this part accounts for mid blocks in both the encoder and the decoder
794
+ for i in range(2):
795
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
796
+ sd_mid_res_prefix = f"mid.block_{i+1}."
797
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
798
+
799
+ vae_conversion_map_attn = [
800
+ # (stable-diffusion, HF Diffusers)
801
+ ("norm.", "group_norm."),
802
+ ("q.", "query."),
803
+ ("k.", "key."),
804
+ ("v.", "value."),
805
+ ("proj_out.", "proj_attn."),
806
+ ]
807
+
808
+ mapping = {k: k for k in vae_state_dict.keys()}
809
+ for k, v in mapping.items():
810
+ for sd_part, hf_part in vae_conversion_map:
811
+ v = v.replace(hf_part, sd_part)
812
+ mapping[k] = v
813
+ for k, v in mapping.items():
814
+ if "attentions" in k:
815
+ for sd_part, hf_part in vae_conversion_map_attn:
816
+ v = v.replace(hf_part, sd_part)
817
+ mapping[k] = v
818
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
819
+ weights_to_convert = ["q", "k", "v", "proj_out"]
820
+ for k, v in new_state_dict.items():
821
+ for weight_name in weights_to_convert:
822
+ if f"mid.attn_1.{weight_name}.weight" in k:
823
+ # print(f"Reshaping {k} for SD format")
824
+ new_state_dict[k] = reshape_weight_for_sd(v)
825
+
826
+ return new_state_dict
827
+
828
+
829
+ # endregion
830
+
831
+ # region custom model loading/saving
832
+
833
+ def is_safetensors(path):
834
+ return os.path.splitext(path)[1].lower() == '.safetensors'
835
+
836
+
837
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
838
+ # support models whose text encoder is stored in a different layout (missing 'text_model')
839
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
840
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
841
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
842
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
843
+ ]
844
+
845
+ if is_safetensors(ckpt_path):
846
+ checkpoint = None
847
+ state_dict = load_file(ckpt_path, "cpu")
848
+ else:
849
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
850
+ if "state_dict" in checkpoint:
851
+ state_dict = checkpoint["state_dict"]
852
+ else:
853
+ state_dict = checkpoint
854
+ checkpoint = None
855
+
856
+ key_reps = []
857
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
858
+ for key in state_dict.keys():
859
+ if key.startswith(rep_from):
860
+ new_key = rep_to + key[len(rep_from):]
861
+ key_reps.append((key, new_key))
862
+
863
+ for key, new_key in key_reps:
864
+ state_dict[new_key] = state_dict[key]
865
+ del state_dict[key]
866
+
867
+ return checkpoint, state_dict
868
+
869
+
870
+ # TODO the dtype handling looks unreliable and needs checking; not yet verified that the text_encoder can be created with the specified dtype
871
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
872
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
873
+ if dtype is not None:
874
+ for k, v in state_dict.items():
875
+ if type(v) is torch.Tensor:
876
+ state_dict[k] = v.to(dtype)
877
+
878
+ # Convert the UNet2DConditionModel model.
879
+ unet_config = create_unet_diffusers_config(v2)
880
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
881
+
882
+ unet = UNet2DConditionModel(**unet_config)
883
+ info = unet.load_state_dict(converted_unet_checkpoint)
884
+ print("loading u-net:", info)
885
+
886
+ # Convert the VAE model.
887
+ vae_config = create_vae_diffusers_config()
888
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
889
+
890
+ vae = AutoencoderKL(**vae_config)
891
+ info = vae.load_state_dict(converted_vae_checkpoint)
892
+ print("loading vae:", info)
893
+
894
+ # convert text_model
895
+ if v2:
896
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
897
+ cfg = CLIPTextConfig(
898
+ vocab_size=49408,
899
+ hidden_size=1024,
900
+ intermediate_size=4096,
901
+ num_hidden_layers=23,
902
+ num_attention_heads=16,
903
+ max_position_embeddings=77,
904
+ hidden_act="gelu",
905
+ layer_norm_eps=1e-05,
906
+ dropout=0.0,
907
+ attention_dropout=0.0,
908
+ initializer_range=0.02,
909
+ initializer_factor=1.0,
910
+ pad_token_id=1,
911
+ bos_token_id=0,
912
+ eos_token_id=2,
913
+ model_type="clip_text_model",
914
+ projection_dim=512,
915
+ torch_dtype="float32",
916
+ transformers_version="4.25.0.dev0",
917
+ )
918
+ text_model = CLIPTextModel._from_config(cfg)
919
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
920
+ else:
921
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
922
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
923
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
924
+ print("loading text encoder:", info)
925
+
926
+ return text_model, vae, unet
927
+
928
+
929
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
930
+ def convert_key(key):
931
+ # remove position_ids
932
+ if ".position_ids" in key:
933
+ return None
934
+
935
+ # common
936
+ key = key.replace("text_model.encoder.", "transformer.")
937
+ key = key.replace("text_model.", "")
938
+ if "layers" in key:
939
+ # resblocks conversion
940
+ key = key.replace(".layers.", ".resblocks.")
941
+ if ".layer_norm" in key:
942
+ key = key.replace(".layer_norm", ".ln_")
943
+ elif ".mlp." in key:
944
+ key = key.replace(".fc1.", ".c_fc.")
945
+ key = key.replace(".fc2.", ".c_proj.")
946
+ elif '.self_attn.out_proj' in key:
947
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
948
+ elif '.self_attn.' in key:
949
+ key = None # special case, handled later
950
+ else:
951
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
952
+ elif '.position_embedding' in key:
953
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
954
+ elif '.token_embedding' in key:
955
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
956
+ elif 'final_layer_norm' in key:
957
+ key = key.replace("final_layer_norm", "ln_final")
958
+ return key
959
+
960
+ keys = list(checkpoint.keys())
961
+ new_sd = {}
962
+ for key in keys:
963
+ new_key = convert_key(key)
964
+ if new_key is None:
965
+ continue
966
+ new_sd[new_key] = checkpoint[key]
967
+
968
+ # convert attention weights
969
+ for key in keys:
970
+ if 'layers' in key and 'q_proj' in key:
971
+ # concatenate the three projections (q, k, v)
972
+ key_q = key
973
+ key_k = key.replace("q_proj", "k_proj")
974
+ key_v = key.replace("q_proj", "v_proj")
975
+
976
+ value_q = checkpoint[key_q]
977
+ value_k = checkpoint[key_k]
978
+ value_v = checkpoint[key_v]
979
+ value = torch.cat([value_q, value_k, value_v])
980
+
981
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
982
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
983
+ new_sd[new_key] = value
984
+
985
+ # whether to fabricate the last layer and other missing weights
986
+ if make_dummy_weights:
987
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
988
+ keys = list(new_sd.keys())
989
+ for key in keys:
990
+ if key.startswith("transformer.resblocks.22."):
991
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without cloning, saving with safetensors fails
992
+
993
+ # create weights that are not included in the Diffusers model
994
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
995
+ new_sd['logit_scale'] = torch.tensor(1)
996
+
997
+ return new_sd
998
+
999
+
1000
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
1001
+ if ckpt_path is not None:
1002
+ # read epoch/step from it; also reload the checkpoint (including the VAE) when, e.g., the VAE is not in memory
1003
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1004
+ if checkpoint is None: # safetensors, or a ckpt that is just a state_dict
1005
+ checkpoint = {}
1006
+ strict = False
1007
+ else:
1008
+ strict = True
1009
+ if "state_dict" in state_dict:
1010
+ del state_dict["state_dict"]
1011
+ else:
1012
+ # create a new checkpoint
1013
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1014
+ checkpoint = {}
1015
+ state_dict = {}
1016
+ strict = False
1017
+
1018
+ def update_sd(prefix, sd):
1019
+ for k, v in sd.items():
1020
+ key = prefix + k
1021
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1022
+ if save_dtype is not None:
1023
+ v = v.detach().clone().to("cpu").to(save_dtype)
1024
+ state_dict[key] = v
1025
+
1026
+ # Convert the UNet model
1027
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1028
+ update_sd("model.diffusion_model.", unet_state_dict)
1029
+
1030
+ # Convert the text encoder model
1031
+ if v2:
1032
+ make_dummy = ckpt_path is None # if there is no source checkpoint, insert dummy weights (e.g. fabricate the last layer by copying the previous one)
1033
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1034
+ update_sd("cond_stage_model.model.", text_enc_dict)
1035
+ else:
1036
+ text_enc_dict = text_encoder.state_dict()
1037
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1038
+
1039
+ # Convert the VAE
1040
+ if vae is not None:
1041
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1042
+ update_sd("first_stage_model.", vae_dict)
1043
+
1044
+ # Put together new checkpoint
1045
+ key_count = len(state_dict.keys())
1046
+ new_ckpt = {'state_dict': state_dict}
1047
+
1048
+ if 'epoch' in checkpoint:
1049
+ epochs += checkpoint['epoch']
1050
+ if 'global_step' in checkpoint:
1051
+ steps += checkpoint['global_step']
1052
+
1053
+ new_ckpt['epoch'] = epochs
1054
+ new_ckpt['global_step'] = steps
1055
+
1056
+ if is_safetensors(output_file):
1057
+ # TODO consider removing non-Tensor values from the dict
1058
+ save_file(state_dict, output_file)
1059
+ else:
1060
+ torch.save(new_ckpt, output_file)
1061
+
1062
+ return key_count
1063
+
1064
+
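# Round-trip sketch for the helpers above (file paths are hypothetical):
#
#   text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
#       False, "models/model.safetensors", dtype=torch.float16)
#   # ... modify the weights ...
#   save_stable_diffusion_checkpoint(
#       False, "models/model-out.safetensors", text_encoder, unet,
#       ckpt_path="models/model.safetensors", epochs=0, steps=0,
#       save_dtype=torch.float16, vae=vae)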
1065
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1066
+ if pretrained_model_name_or_path is None:
1067
+ # load default settings for v1/v2
1068
+ if v2:
1069
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1070
+ else:
1071
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1072
+
1073
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1074
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1075
+ if vae is None:
1076
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1077
+
1078
+ pipeline = StableDiffusionPipeline(
1079
+ unet=unet,
1080
+ text_encoder=text_encoder,
1081
+ vae=vae,
1082
+ scheduler=scheduler,
1083
+ tokenizer=tokenizer,
1084
+ safety_checker=None,
1085
+ feature_extractor=None,
1086
+ requires_safety_checker=None,
1087
+ )
1088
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1089
+
1090
+
1091
+ VAE_PREFIX = "first_stage_model."
1092
+
1093
+
1094
+ def load_vae(vae_id, dtype):
1095
+ print(f"load VAE: {vae_id}")
1096
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1097
+ # Diffusers local/remote
1098
+ try:
1099
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1100
+ except EnvironmentError as e:
1101
+ print(f"exception occurs in loading vae: {e}")
1102
+ print("retry with subfolder='vae'")
1103
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1104
+ return vae
1105
+
1106
+ # local
1107
+ vae_config = create_vae_diffusers_config()
1108
+
1109
+ if vae_id.endswith(".bin"):
1110
+ # SD 1.5 VAE on Huggingface
1111
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1112
+ else:
1113
+ # StableDiffusion
1114
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1115
+ else torch.load(vae_id, map_location="cpu"))
1116
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1117
+
1118
+ # vae only or full model
1119
+ full_model = False
1120
+ for vae_key in vae_sd:
1121
+ if vae_key.startswith(VAE_PREFIX):
1122
+ full_model = True
1123
+ break
1124
+ if not full_model:
1125
+ sd = {}
1126
+ for key, value in vae_sd.items():
1127
+ sd[VAE_PREFIX + key] = value
1128
+ vae_sd = sd
1129
+ del sd
1130
+
1131
+ # Convert the VAE model.
1132
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1133
+
1134
+ vae = AutoencoderKL(**vae_config)
1135
+ vae.load_state_dict(converted_vae_checkpoint)
1136
+ return vae
1137
+
1138
+ # endregion
1139
+
1140
+
1141
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1142
+ max_width, max_height = max_reso
1143
+ max_area = (max_width // divisible) * (max_height // divisible)
1144
+
1145
+ resos = set()
1146
+
1147
+ size = int(math.sqrt(max_area)) * divisible
1148
+ resos.add((size, size))
1149
+
1150
+ size = min_size
1151
+ while size <= max_size:
1152
+ width = size
1153
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1154
+ resos.add((width, height))
1155
+ resos.add((height, width))
1156
+
1157
+ # # make additional resos
1158
+ # if width >= height and width - divisible >= min_size:
1159
+ # resos.add((width - divisible, height))
1160
+ # resos.add((height, width - divisible))
1161
+ # if height >= width and height - divisible >= min_size:
1162
+ # resos.add((width, height - divisible))
1163
+ # resos.add((height - divisible, width))
1164
+
1165
+ size += divisible
1166
+
1167
+ resos = list(resos)
1168
+ resos.sort()
1169
+
1170
+ aspect_ratios = [w / h for w, h in resos]
1171
+ return resos, aspect_ratios
1172
+
1173
+
1174
+ if __name__ == '__main__':
1175
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
1176
+ print(len(resos))
1177
+ print(resos)
1178
+ print(aspect_ratios)
1179
+
1180
+ ars = set()
1181
+ for ar in aspect_ratios:
1182
+ if ar in ars:
1183
+ print("error! duplicate ar:", ar)
1184
+ ars.add(ar)
lycoris/kohya_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ # part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
2
+
3
+ import hashlib
4
+ import safetensors
5
+ from io import BytesIO
6
+
7
+
8
+ def addnet_hash_legacy(b):
9
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
10
+ m = hashlib.sha256()
11
+
12
+ b.seek(0x100000)
13
+ m.update(b.read(0x10000))
14
+ return m.hexdigest()[0:8]
15
+
16
+
17
+ def addnet_hash_safetensors(b):
18
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
19
+ hash_sha256 = hashlib.sha256()
20
+ blksize = 1024 * 1024
21
+
22
+ b.seek(0)
23
+ header = b.read(8)
24
+ n = int.from_bytes(header, "little")
25
+
26
+ offset = n + 8
27
+ b.seek(offset)
28
+ for chunk in iter(lambda: b.read(blksize), b""):
29
+ hash_sha256.update(chunk)
30
+
31
+ return hash_sha256.hexdigest()
32
+
33
+
34
+ def precalculate_safetensors_hashes(tensors, metadata):
35
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
36
+ save time on indexing the model later."""
37
+
38
+ # Because writing user metadata to the file can change the result of
39
+ # sd_models.model_hash(), only retain the training metadata for purposes of
40
+ # calculating the hash, as they are meant to be immutable
41
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
42
+
43
+ bytes = safetensors.torch.save(tensors, metadata)
44
+ b = BytesIO(bytes)
45
+
46
+ model_hash = addnet_hash_safetensors(b)
47
+ legacy_hash = addnet_hash_legacy(b)
48
+ return model_hash, legacy_hash
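# Usage sketch (hypothetical tensors; the "sshs_*" metadata key names are assumed to be
# the ones the additional-networks extension reads, they are not defined in this file):
#
#   model_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
#   metadata["sshs_model_hash"] = model_hash
#   metadata["sshs_legacy_hash"] = legacy_hash
#   safetensors.torch.save_file(tensors, "lora.safetensors", metadata)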
lycoris/locon.py ADDED
@@ -0,0 +1,76 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class LoConModule(nn.Module):
9
+ """
10
+ modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
11
+ """
12
+
13
+ def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
14
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
15
+ super().__init__()
16
+ self.lora_name = lora_name
17
+ self.lora_dim = lora_dim
18
+
19
+ if org_module.__class__.__name__ == 'Conv2d':
20
+ # For general LoCon
21
+ in_dim = org_module.in_channels
22
+ k_size = org_module.kernel_size
23
+ stride = org_module.stride
24
+ padding = org_module.padding
25
+ out_dim = org_module.out_channels
26
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
27
+ self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
28
+ self.op = F.conv2d
29
+ self.extra_args = {
30
+ 'stride': stride,
31
+ 'padding': padding
32
+ }
33
+ else:
34
+ in_dim = org_module.in_features
35
+ out_dim = org_module.out_features
36
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
37
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
38
+ self.op = F.linear
39
+ self.extra_args = {}
40
+ self.shape = org_module.weight.shape
41
+
42
+ if dropout:
43
+ self.dropout = nn.Dropout(dropout)
44
+ else:
45
+ self.dropout = nn.Identity()
46
+
47
+ if type(alpha) == torch.Tensor:
48
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
49
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
50
+ self.scale = alpha / self.lora_dim
51
+ self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
52
+
53
+ # same as microsoft's
54
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
55
+ torch.nn.init.zeros_(self.lora_up.weight)
56
+
57
+ self.multiplier = multiplier
58
+ self.org_module = [org_module]
59
+
60
+ def apply_to(self):
61
+ self.org_module[0].forward = self.forward
62
+
63
+ def make_weight(self):
64
+ wa = self.lora_up.weight
65
+ wb = self.lora_down.weight
66
+ return (wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)).view(self.shape)
67
+
68
+ def forward(self, x):
69
+ bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
70
+ return self.op(
71
+ x,
72
+ (self.org_module[0].weight.data
73
+ + self.dropout(self.make_weight()) * self.multiplier * self.scale),
74
+ bias,
75
+ **self.extra_args,
76
+ )
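# Usage sketch (hypothetical layer, not part of the file): wrap an existing Conv2d and
# reroute its forward pass through the low-rank branch.
#
#   conv = nn.Conv2d(320, 320, 3, padding=1)
#   locon = LoConModule("lora_unet_example", conv, multiplier=1.0, lora_dim=4, alpha=1)
#   locon.apply_to()                       # conv.forward now adds the LoCon delta
#   y = conv(torch.randn(1, 320, 64, 64))  # -> (1, 320, 64, 64)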
lycoris/loha.py ADDED
@@ -0,0 +1,116 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class HadaWeight(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1), dropout=nn.Identity()):
11
+ ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
12
+ diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
13
+ return orig_weight.reshape(diff_weight.shape) + dropout(diff_weight)
14
+
15
+ @staticmethod
16
+ def backward(ctx, grad_out):
17
+ (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
18
+ temp = grad_out*(w2a@w2b)*scale
19
+ grad_w1a = temp @ w1b.T
20
+ grad_w1b = w1a.T @ temp
21
+
22
+ temp = grad_out * (w1a@w1b)*scale
23
+ grad_w2a = temp @ w2b.T
24
+ grad_w2b = w2a.T @ temp
25
+
26
+ del temp
27
+ return grad_out, grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
28
+
29
+
30
+ def make_weight(orig_weight, w1a, w1b, w2a, w2b, scale):
31
+ return HadaWeight.apply(orig_weight, w1a, w1b, w2a, w2b, scale)
32
+
33
+
34
+ class LohaModule(nn.Module):
35
+ """
36
+ Hadamard product implementation for Low Rank Adaptation
37
+ """
38
+
39
+ def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
40
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
41
+ super().__init__()
42
+ self.lora_name = lora_name
43
+ self.lora_dim = lora_dim
44
+
45
+ self.shape = org_module.weight.shape
46
+ if org_module.__class__.__name__ == 'Conv2d':
47
+ in_dim = org_module.in_channels
48
+ k_size = org_module.kernel_size
49
+ out_dim = org_module.out_channels
50
+ shape = (out_dim, in_dim*k_size[0]*k_size[1])
51
+ self.op = F.conv2d
52
+ self.extra_args = {
53
+ "stride": org_module.stride,
54
+ "padding": org_module.padding,
55
+ "dilation": org_module.dilation,
56
+ "groups": org_module.groups
57
+ }
58
+ else:
59
+ in_dim = org_module.in_features
60
+ out_dim = org_module.out_features
61
+ shape = (out_dim, in_dim)
62
+ self.op = F.linear
63
+ self.extra_args = {}
64
+
65
+ self.hada_w1_a = nn.Parameter(torch.empty(shape[0], lora_dim))
66
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
67
+
68
+ self.hada_w2_a = nn.Parameter(torch.empty(shape[0], lora_dim))
69
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
70
+
71
+ if dropout:
72
+ self.dropout = nn.Dropout(dropout)
73
+ else:
74
+ self.dropout = nn.Identity()
75
+
76
+ if type(alpha) == torch.Tensor:
77
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
78
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
79
+ self.scale = alpha / self.lora_dim
80
+ self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
81
+
82
+ # Need more experiences on init method
83
+ torch.nn.init.normal_(self.hada_w1_b, std=1)
84
+ torch.nn.init.normal_(self.hada_w2_b, std=0.05)
85
+ torch.nn.init.normal_(self.hada_w1_a, std=1)
86
+ torch.nn.init.constant_(self.hada_w2_a, 0)
87
+
88
+ self.multiplier = multiplier
89
+ self.org_module = [org_module] # remove in applying
90
+ self.grad_ckpt = False
91
+
92
+ def apply_to(self):
93
+ self.org_module[0].forward = self.forward
94
+
95
+ def get_weight(self):
96
+ d_weight = self.hada_w1_a @ self.hada_w1_b
97
+ d_weight *= self.hada_w2_a @ self.hada_w2_b
98
+ return (d_weight).reshape(self.shape)
99
+
100
+ @torch.enable_grad()
101
+ def forward(self, x):
102
+ # print(torch.mean(torch.abs(self.orig_w1a.to(x.device) - self.hada_w1_a)), end='\r')
103
+ weight = make_weight(
104
+ self.org_module[0].weight.data,
105
+ self.hada_w1_a, self.hada_w1_b,
106
+ self.hada_w2_a, self.hada_w2_b,
107
+ scale = torch.tensor(self.scale*self.multiplier),
108
+ )
109
+
110
+ bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
111
+ return self.op(
112
+ x,
113
+ weight.view(self.shape),
114
+ bias,
115
+ **self.extra_args
116
+ )
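# In effect (matching HadaWeight.forward above), the weight applied by a LoHa module is
#
#   delta_W = (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b) * (alpha / lora_dim) * multiplier
#   W_eff   = org_module.weight + delta_W.reshape(org_module.weight.shape)
#
# i.e. a Hadamard (element-wise) product of two rank-`lora_dim` factorizations.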
lycoris/utils.py ADDED
@@ -0,0 +1,271 @@
1
+ from typing import *
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import torch.linalg as linalg
8
+
9
+ from tqdm import tqdm
10
+
11
+
12
+ def extract_conv(
13
+ weight: Union[torch.Tensor, nn.Parameter],
14
+ mode = 'fixed',
15
+ mode_param = 0,
16
+ device = 'cpu',
17
+ ) -> Tuple[nn.Parameter, nn.Parameter]:
18
+ out_ch, in_ch, kernel_size, _ = weight.shape
19
+
20
+ U, S, Vh = linalg.svd(weight.reshape(out_ch, -1).to(device))
21
+
22
+ if mode=='fixed':
23
+ lora_rank = mode_param
24
+ elif mode=='threshold':
25
+ assert mode_param>=0
26
+ lora_rank = torch.sum(S>mode_param)
27
+ elif mode=='ratio':
28
+ assert 1>=mode_param>=0
29
+ min_s = torch.max(S)*mode_param
30
+ lora_rank = torch.sum(S>min_s)
31
+ elif mode=='percentile':
32
+ assert 1>=mode_param>=0
33
+ s_cum = torch.cumsum(S, dim=0)
34
+ min_cum_sum = mode_param * torch.sum(S)
35
+ lora_rank = torch.sum(s_cum<min_cum_sum)
36
+ lora_rank = max(1, lora_rank)
37
+ lora_rank = min(out_ch, in_ch, lora_rank)
38
+
39
+ U = U[:, :lora_rank]
40
+ S = S[:lora_rank]
41
+ U = U @ torch.diag(S)
42
+ Vh = Vh[:lora_rank, :]
43
+
44
+ extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
45
+ extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
46
+ del U, S, Vh, weight
47
+ return extract_weight_A, extract_weight_B
48
+
49
+
50
+ def merge_conv(
51
+ weight_a: Union[torch.Tensor, nn.Parameter],
52
+ weight_b: Union[torch.Tensor, nn.Parameter],
53
+ device = 'cpu'
54
+ ):
55
+ rank, in_ch, kernel_size, k_ = weight_a.shape
56
+ out_ch, rank_, _, _ = weight_b.shape
57
+ assert rank == rank_ and kernel_size == k_
58
+
59
+ wa = weight_a.to(device)
60
+ wb = weight_b.to(device)
61
+
62
+ if device == 'cpu':
63
+ wa = wa.float()
64
+ wb = wb.float()
65
+
66
+ merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
67
+ weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
68
+ del wb, wa
69
+ return weight
70
+
71
+
72
+ def extract_linear(
73
+ weight: Union[torch.Tensor, nn.Parameter],
74
+ mode = 'fixed',
75
+ mode_param = 0,
76
+ device = 'cpu',
77
+ ) -> Tuple[nn.Parameter, nn.Parameter]:
78
+ out_ch, in_ch = weight.shape
79
+
80
+ U, S, Vh = linalg.svd(weight.to(device))
81
+
82
+ if mode=='fixed':
83
+ lora_rank = mode_param
84
+ elif mode=='threshold':
85
+ assert mode_param>=0
86
+ lora_rank = torch.sum(S>mode_param)
87
+ elif mode=='ratio':
88
+ assert 1>=mode_param>=0
89
+ min_s = torch.max(S)*mode_param
90
+ lora_rank = torch.sum(S>min_s)
91
+ elif mode=='percentile':
92
+ assert 1>=mode_param>=0
93
+ s_cum = torch.cumsum(S, dim=0)
94
+ min_cum_sum = mode_param * torch.sum(S)
95
+ lora_rank = torch.sum(s_cum<min_cum_sum)
96
+ lora_rank = max(1, lora_rank)
97
+ lora_rank = min(out_ch, in_ch, lora_rank)
98
+
99
+ U = U[:, :lora_rank]
100
+ S = S[:lora_rank]
101
+ U = U @ torch.diag(S)
102
+ Vh = Vh[:lora_rank, :]
103
+
104
+ extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
105
+ extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
106
+ del U, S, Vh, weight
107
+ return extract_weight_A, extract_weight_B
108
+
109
+
110
+ def merge_linear(
111
+ weight_a: Union[torch.Tensor, nn.Parameter],
112
+ weight_b: Union[torch.Tensor, nn.Parameter],
113
+ device = 'cpu'
114
+ ):
115
+ rank, in_ch = weight_a.shape
116
+ out_ch, rank_ = weight_b.shape
117
+ assert rank == rank_
118
+
119
+ wa = weight_a.to(device)
120
+ wb = weight_b.to(device)
121
+
122
+ if device == 'cpu':
123
+ wa = wa.float()
124
+ wb = wb.float()
125
+
126
+ weight = wb @ wa
127
+ del wb, wa
128
+ return weight
129
+
130
+
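# Sanity-check sketch (illustrative): for a difference that is exactly low rank,
# extract_linear followed by merge_linear reconstructs it up to SVD round-off.
#
#   delta = torch.randn(640, 8) @ torch.randn(8, 320)         # rank-8 difference
#   a, b = extract_linear(delta, mode='fixed', mode_param=8)  # a: (8, 320), b: (640, 8)
#   assert torch.allclose(merge_linear(a, b), delta, atol=1e-3)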
131
+ def extract_diff(
132
+ base_model,
133
+ db_model,
134
+ mode = 'fixed',
135
+ linear_mode_param = 0,
136
+ conv_mode_param = 0,
137
+ extract_device = 'cpu'
138
+ ):
139
+ UNET_TARGET_REPLACE_MODULE = [
140
+ "Transformer2DModel",
141
+ "Attention",
142
+ "ResnetBlock2D",
143
+ "Downsample2D",
144
+ "Upsample2D"
145
+ ]
146
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
147
+ LORA_PREFIX_UNET = 'lora_unet'
148
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
149
+ def make_state_dict(
150
+ prefix,
151
+ root_module: torch.nn.Module,
152
+ target_module: torch.nn.Module,
153
+ target_replace_modules
154
+ ):
155
+ loras = {}
156
+ temp = {}
157
+
158
+ for name, module in root_module.named_modules():
159
+ if module.__class__.__name__ in target_replace_modules:
160
+ temp[name] = {}
161
+ for child_name, child_module in module.named_modules():
162
+ if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
163
+ continue
164
+ temp[name][child_name] = child_module.weight
165
+
166
+ for name, module in tqdm(list(target_module.named_modules())):
167
+ if name in temp:
168
+ weights = temp[name]
169
+ for child_name, child_module in module.named_modules():
170
+ lora_name = prefix + '.' + name + '.' + child_name
171
+ lora_name = lora_name.replace('.', '_')
172
+
173
+ layer = child_module.__class__.__name__
174
+ if layer == 'Linear':
175
+ extract_a, extract_b = extract_linear(
176
+ (child_module.weight - weights[child_name]),
177
+ mode,
178
+ linear_mode_param,
179
+ device = extract_device,
180
+ )
181
+ elif layer == 'Conv2d':
182
+ is_linear = (child_module.weight.shape[2] == 1
183
+ and child_module.weight.shape[3] == 1)
184
+ extract_a, extract_b = extract_conv(
185
+ (child_module.weight - weights[child_name]),
186
+ mode,
187
+ linear_mode_param if is_linear else conv_mode_param,
188
+ device = extract_device,
189
+ )
190
+ else:
191
+ continue
192
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
193
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
194
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
195
+ del extract_a, extract_b
196
+ return loras
197
+
198
+ text_encoder_loras = make_state_dict(
199
+ LORA_PREFIX_TEXT_ENCODER,
200
+ base_model[0], db_model[0],
201
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
202
+ )
203
+
204
+ unet_loras = make_state_dict(
205
+ LORA_PREFIX_UNET,
206
+ base_model[2], db_model[2],
207
+ UNET_TARGET_REPLACE_MODULE
208
+ )
209
+ print(len(text_encoder_loras), len(unet_loras))
210
+ return text_encoder_loras|unet_loras
211
+
212
+
213
+ def merge_locon(
214
+ base_model,
215
+ locon_state_dict: Dict[str, torch.TensorType],
216
+ scale: float = 1.0,
217
+ device = 'cpu'
218
+ ):
219
+ UNET_TARGET_REPLACE_MODULE = [
220
+ "Transformer2DModel",
221
+ "Attention",
222
+ "ResnetBlock2D",
223
+ "Downsample2D",
224
+ "Upsample2D"
225
+ ]
226
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
227
+ LORA_PREFIX_UNET = 'lora_unet'
228
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
229
+ def merge(
230
+ prefix,
231
+ root_module: torch.nn.Module,
232
+ target_replace_modules
233
+ ):
234
+ temp = {}
235
+
236
+ for name, module in tqdm(list(root_module.named_modules())):
237
+ if module.__class__.__name__ in target_replace_modules:
238
+ temp[name] = {}
239
+ for child_name, child_module in module.named_modules():
240
+ layer = child_module.__class__.__name__
241
+ if layer not in {'Linear', 'Conv2d'}:
242
+ continue
243
+ lora_name = prefix + '.' + name + '.' + child_name
244
+ lora_name = lora_name.replace('.', '_')
245
+
246
+ down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
247
+ up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
248
+ alpha = locon_state_dict[f'{lora_name}.alpha'].float()
249
+ rank = down.shape[0]
250
+
251
+ if layer == 'Conv2d':
252
+ delta = merge_conv(down, up, device)
253
+ child_module.weight.requires_grad_(False)
254
+ child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
255
+ del delta
256
+ elif layer == 'Linear':
257
+ delta = merge_linear(down, up, device)
258
+ child_module.weight.requires_grad_(False)
259
+ child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
260
+ del delta
261
+
262
+ merge(
263
+ LORA_PREFIX_TEXT_ENCODER,
264
+ base_model[0],
265
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
266
+ )
267
+ merge(
268
+ LORA_PREFIX_UNET,
269
+ base_model[2],
270
+ UNET_TARGET_REPLACE_MODULE
271
+ )
networks/__pycache__/lora.cpython-310.pyc ADDED
Binary file (7.36 kB).
 
networks/check_lora_weights.py ADDED
@@ -0,0 +1,32 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from safetensors.torch import load_file
5
+
6
+
7
+ def main(file):
8
+ print(f"loading: {file}")
9
+ if os.path.splitext(file)[1] == '.safetensors':
10
+ sd = load_file(file)
11
+ else:
12
+ sd = torch.load(file, map_location='cpu')
13
+
14
+ values = []
15
+
16
+ keys = list(sd.keys())
17
+ for key in keys:
18
+ if 'lora_up' in key or 'lora_down' in key:
19
+ values.append((key, sd[key]))
20
+ print(f"number of LoRA modules: {len(values)}")
21
+
22
+ for key, value in values:
23
+ value = value.to(torch.float32)
24
+ print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
+
26
+
27
+ if __name__ == '__main__':
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
30
+ args = parser.parse_args()
31
+
32
+ main(args.file)
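The script prints one line per LoRA tensor with its mean and minimum absolute value, which is a quick way to spot untrained modules (lora_up is initialized to zero). A minimal sketch of the same check on an in-memory state dict (the dummy keys are illustrative):

import torch

sd = {
    "lora_unet_example.lora_down.weight": torch.randn(4, 320),
    "lora_unet_example.lora_up.weight": torch.zeros(320, 4),  # all zeros -> effectively untrained
}
for key, value in sd.items():
    if 'lora_up' in key or 'lora_down' in key:
        value = value.to(torch.float32)
        print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")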
networks/extract_lora_from_models.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # extract approximating LoRA by svd from two SD models
2
+ # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo!
4
+
5
+ import argparse
6
+ import os
7
+ import torch
8
+ from safetensors.torch import load_file, save_file
9
+ from tqdm import tqdm
10
+ import library.model_util as model_util
11
+ import lora
12
+
13
+
14
+ CLAMP_QUANTILE = 0.99
15
+ MIN_DIFF = 1e-6
16
+
17
+
18
+ def save_to_file(file_name, model, state_dict, dtype):
19
+ if dtype is not None:
20
+ for key in list(state_dict.keys()):
21
+ if type(state_dict[key]) == torch.Tensor:
22
+ state_dict[key] = state_dict[key].to(dtype)
23
+
24
+ if os.path.splitext(file_name)[1] == '.safetensors':
25
+ save_file(model, file_name)
26
+ else:
27
+ torch.save(model, file_name)
28
+
29
+
30
+ def svd(args):
31
+ def str_to_dtype(p):
32
+ if p == 'float':
33
+ return torch.float
34
+ if p == 'fp16':
35
+ return torch.float16
36
+ if p == 'bf16':
37
+ return torch.bfloat16
38
+ return None
39
+
40
+ save_dtype = str_to_dtype(args.save_precision)
41
+
42
+ print(f"loading SD model : {args.model_org}")
43
+ text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
44
+ print(f"loading SD model : {args.model_tuned}")
45
+ text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
+
47
+ # create LoRA network to extract weights: Use dim (rank) as alpha
48
+ lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
49
+ lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
50
+ assert len(lora_network_o.text_encoder_loras) == len(
51
+ lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
52
+
53
+ # get diffs
54
+ diffs = {}
55
+ text_encoder_different = False
56
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
57
+ lora_name = lora_o.lora_name
58
+ module_o = lora_o.org_module
59
+ module_t = lora_t.org_module
60
+ diff = module_t.weight - module_o.weight
61
+
62
+ # Text Encoder might be same
63
+ if torch.max(torch.abs(diff)) > MIN_DIFF:
64
+ text_encoder_different = True
65
+
66
+ diff = diff.float()
67
+ diffs[lora_name] = diff
68
+
69
+ if not text_encoder_different:
70
+ print("Text encoder is same. Extract U-Net only.")
71
+ lora_network_o.text_encoder_loras = []
72
+ diffs = {}
73
+
74
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
75
+ lora_name = lora_o.lora_name
76
+ module_o = lora_o.org_module
77
+ module_t = lora_t.org_module
78
+ diff = module_t.weight - module_o.weight
79
+ diff = diff.float()
80
+
81
+ if args.device:
82
+ diff = diff.to(args.device)
83
+
84
+ diffs[lora_name] = diff
85
+
86
+ # make LoRA with svd
87
+ print("calculating by svd")
88
+ rank = args.dim
89
+ lora_weights = {}
90
+ with torch.no_grad():
91
+ for lora_name, mat in tqdm(list(diffs.items())):
92
+ conv2d = (len(mat.size()) == 4)
93
+ if conv2d:
94
+ mat = mat.squeeze()
95
+
96
+ U, S, Vh = torch.linalg.svd(mat)
97
+
98
+ U = U[:, :rank]
99
+ S = S[:rank]
100
+ U = U @ torch.diag(S)
101
+
102
+ Vh = Vh[:rank, :]
103
+
104
+ dist = torch.cat([U.flatten(), Vh.flatten()])
105
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
106
+ low_val = -hi_val
107
+
108
+ U = U.clamp(low_val, hi_val)
109
+ Vh = Vh.clamp(low_val, hi_val)
110
+
111
+ lora_weights[lora_name] = (U, Vh)
112
+
113
+ # make state dict for LoRA
114
+ lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
115
+ lora_sd = lora_network_o.state_dict()
116
+ print(f"LoRA has {len(lora_sd)} weights.")
117
+
118
+ for key in list(lora_sd.keys()):
119
+ if "alpha" in key:
120
+ continue
121
+
122
+ lora_name = key.split('.')[0]
123
+ i = 0 if "lora_up" in key else 1
124
+
125
+ weights = lora_weights[lora_name][i]
126
+ # print(key, i, weights.size(), lora_sd[key].size())
127
+ if len(lora_sd[key].size()) == 4:
128
+ weights = weights.unsqueeze(2).unsqueeze(3)
129
+
130
+ assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
131
+ lora_sd[key] = weights
132
+
133
+ # load state dict to LoRA and save it
134
+ info = lora_network_o.load_state_dict(lora_sd)
135
+ print(f"Loading extracted LoRA weights: {info}")
136
+
137
+ dir_name = os.path.dirname(args.save_to)
138
+ if dir_name and not os.path.exists(dir_name):
139
+ os.makedirs(dir_name, exist_ok=True)
140
+
141
+ # minimum metadata
142
+ metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
143
+
144
+ lora_network_o.save_weights(args.save_to, save_dtype, metadata)
145
+ print(f"LoRA weights are saved to: {args.save_to}")
146
+
147
+
148
+ if __name__ == '__main__':
149
+ parser = argparse.ArgumentParser()
150
+ parser.add_argument("--v2", action='store_true',
151
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
152
+ parser.add_argument("--save_precision", type=str, default=None,
153
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時に精度を変更して保存する、省略時はfloat")
154
+ parser.add_argument("--model_org", type=str, default=None,
155
+ help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
156
+ parser.add_argument("--model_tuned", type=str, default=None,
157
+ help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
158
+ parser.add_argument("--save_to", type=str, default=None,
159
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
160
+ parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
161
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
+
163
+ args = parser.parse_args()
164
+ svd(args)
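The heart of the extraction is a truncated SVD of each weight difference, split into an (up, down) pair and clamped at a quantile so the fp16 copy does not overflow. A minimal self-contained sketch of that step (shapes and rank are illustrative):

import torch

def low_rank_pair(diff: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
    U, S, Vh = torch.linalg.svd(diff)          # diff: (out_features, in_features)
    up = U[:, :rank] @ torch.diag(S[:rank])    # becomes lora_up
    down = Vh[:rank, :]                        # becomes lora_down
    hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    return up.clamp(-hi, hi), down.clamp(-hi, hi)

up, down = low_rank_pair(torch.randn(320, 768), rank=4)
print(up.shape, down.shape)  # torch.Size([320, 4]) torch.Size([4, 768])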
networks/lora.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import List
9
+ import torch
10
+
11
+ from library import train_util
12
+
13
+
14
+ class LoRAModule(torch.nn.Module):
15
+ """
16
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
17
+ """
18
+
19
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
20
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
21
+ super().__init__()
22
+ self.lora_name = lora_name
23
+ self.lora_dim = lora_dim
24
+
25
+ if org_module.__class__.__name__ == 'Conv2d':
26
+ in_dim = org_module.in_channels
27
+ out_dim = org_module.out_channels
28
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
29
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
30
+ else:
31
+ in_dim = org_module.in_features
32
+ out_dim = org_module.out_features
33
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
34
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
35
+
36
+ if type(alpha) == torch.Tensor:
37
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
38
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
39
+ self.scale = alpha / self.lora_dim
40
+ self.register_buffer('alpha', torch.tensor(alpha)) # stored as a buffer so it can be treated as a constant
41
+
42
+ # same as microsoft's
43
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
44
+ torch.nn.init.zeros_(self.lora_up.weight)
45
+
46
+ self.multiplier = multiplier
47
+ self.org_module = org_module # removed when apply_to() is called
48
+
49
+ def apply_to(self):
50
+ self.org_forward = self.org_module.forward
51
+ self.org_module.forward = self.forward
52
+ del self.org_module
53
+
54
+ def forward(self, x):
55
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
56
+
57
+
58
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
59
+ if network_dim is None:
60
+ network_dim = 4 # default
61
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
62
+ return network
63
+
64
+
65
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
66
+ if os.path.splitext(file)[1] == '.safetensors':
67
+ from safetensors.torch import load_file, safe_open
68
+ weights_sd = load_file(file)
69
+ else:
70
+ weights_sd = torch.load(file, map_location='cpu')
71
+
72
+ # get dim (rank)
73
+ network_alpha = None
74
+ network_dim = None
75
+ for key, value in weights_sd.items():
76
+ if network_alpha is None and 'alpha' in key:
77
+ network_alpha = value
78
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
79
+ network_dim = value.size()[0]
80
+
81
+ if network_alpha is None:
82
+ network_alpha = network_dim
83
+
84
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
85
+ network.weights_sd = weights_sd
86
+ return network
87
+
88
+
89
+ class LoRANetwork(torch.nn.Module):
90
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
91
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
+ LORA_PREFIX_UNET = 'lora_unet'
93
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
+
95
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
96
+ super().__init__()
97
+ self.multiplier = multiplier
98
+ self.lora_dim = lora_dim
99
+ self.alpha = alpha
100
+
101
+ # create module instances
102
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
103
+ loras = []
104
+ for name, module in root_module.named_modules():
105
+ if module.__class__.__name__ in target_replace_modules:
106
+ for child_name, child_module in module.named_modules():
107
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
108
+ lora_name = prefix + '.' + name + '.' + child_name
109
+ lora_name = lora_name.replace('.', '_')
110
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
111
+ loras.append(lora)
112
+ return loras
113
+
114
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
115
+ text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
+
118
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
119
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
+
121
+ self.weights_sd = None
122
+
123
+ # assertion
124
+ names = set()
125
+ for lora in self.text_encoder_loras + self.unet_loras:
126
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
+ names.add(lora.lora_name)
128
+
129
+ def load_weights(self, file):
130
+ if os.path.splitext(file)[1] == '.safetensors':
131
+ from safetensors.torch import load_file, safe_open
132
+ self.weights_sd = load_file(file)
133
+ else:
134
+ self.weights_sd = torch.load(file, map_location='cpu')
135
+
136
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
137
+ if self.weights_sd:
138
+ weights_has_text_encoder = weights_has_unet = False
139
+ for key in self.weights_sd.keys():
140
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
141
+ weights_has_text_encoder = True
142
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
143
+ weights_has_unet = True
144
+
145
+ if apply_text_encoder is None:
146
+ apply_text_encoder = weights_has_text_encoder
147
+ else:
148
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
149
+
150
+ if apply_unet is None:
151
+ apply_unet = weights_has_unet
152
+ else:
153
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
154
+ else:
155
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
156
+
157
+ if apply_text_encoder:
158
+ print("enable LoRA for text encoder")
159
+ else:
160
+ self.text_encoder_loras = []
161
+
162
+ if apply_unet:
163
+ print("enable LoRA for U-Net")
164
+ else:
165
+ self.unet_loras = []
166
+
167
+ for lora in self.text_encoder_loras + self.unet_loras:
168
+ lora.apply_to()
169
+ self.add_module(lora.lora_name, lora)
170
+
171
+ if self.weights_sd:
172
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
173
+ info = self.load_state_dict(self.weights_sd, False)
174
+ print(f"weights are loaded: {info}")
175
+
176
+ def enable_gradient_checkpointing(self):
177
+ # not supported
178
+ pass
179
+
180
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
181
+ def enumerate_params(loras):
182
+ params = []
183
+ for lora in loras:
184
+ params.extend(lora.parameters())
185
+ return params
186
+
187
+ self.requires_grad_(True)
188
+ all_params = []
189
+
190
+ if self.text_encoder_loras:
191
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
192
+ if text_encoder_lr is not None:
193
+ param_data['lr'] = text_encoder_lr
194
+ all_params.append(param_data)
195
+
196
+ if self.unet_loras:
197
+ param_data = {'params': enumerate_params(self.unet_loras)}
198
+ if unet_lr is not None:
199
+ param_data['lr'] = unet_lr
200
+ all_params.append(param_data)
201
+
202
+ return all_params
203
+
204
+ def prepare_grad_etc(self, text_encoder, unet):
205
+ self.requires_grad_(True)
206
+
207
+ def on_epoch_start(self, text_encoder, unet):
208
+ self.train()
209
+
210
+ def get_trainable_params(self):
211
+ return self.parameters()
212
+
213
+ def save_weights(self, file, dtype, metadata):
214
+ if metadata is not None and len(metadata) == 0:
215
+ metadata = None
216
+
217
+ state_dict = self.state_dict()
218
+
219
+ if dtype is not None:
220
+ for key in list(state_dict.keys()):
221
+ v = state_dict[key]
222
+ v = v.detach().clone().to("cpu").to(dtype)
223
+ state_dict[key] = v
224
+
225
+ if os.path.splitext(file)[1] == '.safetensors':
226
+ from safetensors.torch import save_file
227
+
228
+ # Precalculate model hashes to save time on indexing
229
+ if metadata is None:
230
+ metadata = {}
231
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
232
+ metadata["sshs_model_hash"] = model_hash
233
+ metadata["sshs_legacy_hash"] = legacy_hash
234
+
235
+ save_file(state_dict, file, metadata)
236
+ else:
237
+ torch.save(state_dict, file)
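Each LoRAModule above monkey-patches the original forward so the layer computes y = Wx + multiplier * (alpha/dim) * up(down(x)), starting as an exact no-op because lora_up is zero-initialized. A minimal sketch of the same idea for a single Linear layer (the class name here is illustrative):

import torch

class TinyLoRA(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.base = base
        self.down = torch.nn.Linear(base.in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.up.weight)   # zero init -> the wrapper starts identical to base
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        return self.base(x) + self.up(self.down(x)) * self.multiplier * self.scale

layer = TinyLoRA(torch.nn.Linear(768, 320))
print(layer(torch.randn(2, 768)).shape)  # torch.Size([2, 320])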
networks/lora_interrogator.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from tqdm import tqdm
4
+ from library import model_util
5
+ import argparse
6
+ from transformers import CLIPTokenizer
7
+ import torch
8
+
9
+ import library.model_util as model_util
10
+ import lora
11
+
12
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
13
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this
14
+
15
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+
18
+ def interrogate(args):
19
+ # prepare the models and tokenizer
20
+ print(f"loading SD model: {args.sd_model}")
21
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
22
+
23
+ print(f"loading LoRA: {args.model}")
24
+ network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
25
+
26
+ # check whether the LoRA has weights for the text encoder (ideally this check belongs in the lora module)
27
+ has_te_weight = False
28
+ for key in network.weights_sd.keys():
29
+ if 'lora_te' in key:
30
+ has_te_weight = True
31
+ break
32
+ if not has_te_weight:
33
+ print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
34
+ return
35
+ del vae
36
+
37
+ print("loading tokenizer")
38
+ if args.v2:
39
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
40
+ else:
41
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
42
+
43
+ text_encoder.to(DEVICE)
44
+ text_encoder.eval()
45
+ unet.to(DEVICE)
46
+ unet.eval() # not strictly needed since the U-Net is never called
47
+
48
+ # examine the tokens one by one
49
+ token_id_start = 0
50
+ token_id_end = max(tokenizer.all_special_ids)
51
+ print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
52
+
53
+ def get_all_embeddings(text_encoder):
54
+ embs = []
55
+ with torch.no_grad():
56
+ for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
57
+ batch = []
58
+ for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
59
+ tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
60
+ # tokens = [tid] # this variant gives less useful results
61
+ batch.append(tokens)
62
+
63
+ # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # including bos/eos seems to make differences clearer [:, 1]
64
+ # support clip skip
65
+ batch = torch.tensor(batch).to(DEVICE)
66
+ if args.clip_skip is None:
67
+ encoder_hidden_states = text_encoder(batch)[0]
68
+ else:
69
+ enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
70
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
71
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
72
+ encoder_hidden_states = encoder_hidden_states.to("cpu")
73
+
74
+ embs.extend(encoder_hidden_states)
75
+ return torch.stack(embs)
76
+
77
+ print("get original text encoder embeddings.")
78
+ orig_embs = get_all_embeddings(text_encoder)
79
+
80
+ network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
81
+ network.to(DEVICE)
82
+ network.eval()
83
+
84
+ print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
85
+ print("get text encoder embeddings with lora.")
86
+ lora_embs = get_all_embeddings(text_encoder)
87
+
88
+ # compare: for now, simply use the mean absolute difference
89
+ print("comparing...")
90
+ diffs = {}
91
+ for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
92
+ diff = torch.mean(torch.abs(orig_emb - lora_emb))
93
+ # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # does not detect differences well
94
+ diff = float(diff.detach().to('cpu').numpy())
95
+ diffs[token_id_start + i] = diff
96
+
97
+ diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
98
+
99
+ # show the results
100
+ print("top 100:")
101
+ for i, (token, diff) in enumerate(diffs_sorted[:100]):
102
+ # if diff < 1e-6:
103
+ # break
104
+ string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
105
+ print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
106
+
107
+
108
+ if __name__ == '__main__':
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument("--v2", action='store_true',
111
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
112
+ parser.add_argument("--sd_model", type=str, default=None,
113
+ help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
114
+ parser.add_argument("--model", type=str, default=None,
115
+ help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
116
+ parser.add_argument("--batch_size", type=int, default=16,
117
+ help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
118
+ parser.add_argument("--clip_skip", type=int, default=None,
119
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
120
+
121
+ args = parser.parse_args()
122
+ interrogate(args)
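The interrogation itself reduces to a mean absolute difference between each token's embedding with and without the LoRA applied, sorted in descending order. A minimal sketch with dummy embeddings (token 42 is deliberately perturbed):

import torch

orig_embs = torch.randn(100, 3, 768)   # per-token hidden states without LoRA (dummy data)
lora_embs = orig_embs.clone()
lora_embs[42] += 0.5                   # pretend the LoRA shifted this token's embedding

diffs = {i: float(torch.mean(torch.abs(o - l))) for i, (o, l) in enumerate(zip(orig_embs, lora_embs))}
print(sorted(diffs.items(), key=lambda x: -x[1])[:3])  # token 42 ranks first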
networks/merge_lora.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ import library.model_util as model_util
8
+ import lora
9
+
10
+
11
+ def load_state_dict(file_name, dtype):
12
+ if os.path.splitext(file_name)[1] == '.safetensors':
13
+ sd = load_file(file_name)
14
+ else:
15
+ sd = torch.load(file_name, map_location='cpu')
16
+ for key in list(sd.keys()):
17
+ if type(sd[key]) == torch.Tensor:
18
+ sd[key] = sd[key].to(dtype)
19
+ return sd
20
+
21
+
22
+ def save_to_file(file_name, model, state_dict, dtype):
23
+ if dtype is not None:
24
+ for key in list(state_dict.keys()):
25
+ if type(state_dict[key]) == torch.Tensor:
26
+ state_dict[key] = state_dict[key].to(dtype)
27
+
28
+ if os.path.splitext(file_name)[1] == '.safetensors':
29
+ save_file(model, file_name)
30
+ else:
31
+ torch.save(model, file_name)
32
+
33
+
34
+ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
35
+ text_encoder.to(merge_dtype)
36
+ unet.to(merge_dtype)
37
+
38
+ # create module map
39
+ name_to_module = {}
40
+ for i, root_module in enumerate([text_encoder, unet]):
41
+ if i == 0:
42
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
43
+ target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
44
+ else:
45
+ prefix = lora.LoRANetwork.LORA_PREFIX_UNET
46
+ target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
47
+
48
+ for name, module in root_module.named_modules():
49
+ if module.__class__.__name__ in target_replace_modules:
50
+ for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
+ lora_name = prefix + '.' + name + '.' + child_name
53
+ lora_name = lora_name.replace('.', '_')
54
+ name_to_module[lora_name] = child_module
55
+
56
+ for model, ratio in zip(models, ratios):
57
+ print(f"loading: {model}")
58
+ lora_sd = load_state_dict(model, merge_dtype)
59
+
60
+ print(f"merging...")
61
+ for key in lora_sd.keys():
62
+ if "lora_down" in key:
63
+ up_key = key.replace("lora_down", "lora_up")
64
+ alpha_key = key[:key.index("lora_down")] + 'alpha'
65
+
66
+ # find original module for this lora
67
+ module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
68
+ if module_name not in name_to_module:
69
+ print(f"no module found for LoRA weight: {key}")
70
+ continue
71
+ module = name_to_module[module_name]
72
+ # print(f"apply {key} to {module}")
73
+
74
+ down_weight = lora_sd[key]
75
+ up_weight = lora_sd[up_key]
76
+
77
+ dim = down_weight.size()[0]
78
+ alpha = lora_sd.get(alpha_key, dim)
79
+ scale = alpha / dim
80
+
81
+ # W <- W + U * D
82
+ weight = module.weight
83
+ if len(weight.size()) == 2:
84
+ # linear
85
+ weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ else:
87
+ # conv2d
88
+ weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
89
+ ).unsqueeze(2).unsqueeze(3) * scale
90
+
91
+ module.weight = torch.nn.Parameter(weight)
92
+
93
+
94
+ def merge_lora_models(models, ratios, merge_dtype):
95
+ base_alphas = {} # alpha for merged model
96
+ base_dims = {}
97
+
98
+ merged_sd = {}
99
+ for model, ratio in zip(models, ratios):
100
+ print(f"loading: {model}")
101
+ lora_sd = load_state_dict(model, merge_dtype)
102
+
103
+ # get alpha and dim
104
+ alphas = {} # alpha for current model
105
+ dims = {} # dims for current model
106
+ for key in lora_sd.keys():
107
+ if 'alpha' in key:
108
+ lora_module_name = key[:key.rfind(".alpha")]
109
+ alpha = float(lora_sd[key].detach().numpy())
110
+ alphas[lora_module_name] = alpha
111
+ if lora_module_name not in base_alphas:
112
+ base_alphas[lora_module_name] = alpha
113
+ elif "lora_down" in key:
114
+ lora_module_name = key[:key.rfind(".lora_down")]
115
+ dim = lora_sd[key].size()[0]
116
+ dims[lora_module_name] = dim
117
+ if lora_module_name not in base_dims:
118
+ base_dims[lora_module_name] = dim
119
+
120
+ for lora_module_name in dims.keys():
121
+ if lora_module_name not in alphas:
122
+ alpha = dims[lora_module_name]
123
+ alphas[lora_module_name] = alpha
124
+ if lora_module_name not in base_alphas:
125
+ base_alphas[lora_module_name] = alpha
126
+
127
+ print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
128
+
129
+ # merge
130
+ print(f"merging...")
131
+ for key in lora_sd.keys():
132
+ if 'alpha' in key:
133
+ continue
134
+
135
+ lora_module_name = key[:key.rfind(".lora_")]
136
+
137
+ base_alpha = base_alphas[lora_module_name]
138
+ alpha = alphas[lora_module_name]
139
+
140
+ scale = math.sqrt(alpha / base_alpha) * ratio
141
+
142
+ if key in merged_sd:
143
+ assert merged_sd[key].size() == lora_sd[key].size(
144
+ ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズ��合いません。v1とv2、または次元数の異なるモデルはマージできません"
145
+ merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
146
+ else:
147
+ merged_sd[key] = lora_sd[key] * scale
148
+
149
+ # set alpha to sd
150
+ for lora_module_name, alpha in base_alphas.items():
151
+ key = lora_module_name + ".alpha"
152
+ merged_sd[key] = torch.tensor(alpha)
153
+
154
+ print("merged model")
155
+ print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
156
+
157
+ return merged_sd
158
+
159
+
160
+ def merge(args):
161
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
162
+
163
+ def str_to_dtype(p):
164
+ if p == 'float':
165
+ return torch.float
166
+ if p == 'fp16':
167
+ return torch.float16
168
+ if p == 'bf16':
169
+ return torch.bfloat16
170
+ return None
171
+
172
+ merge_dtype = str_to_dtype(args.precision)
173
+ save_dtype = str_to_dtype(args.save_precision)
174
+ if save_dtype is None:
175
+ save_dtype = merge_dtype
176
+
177
+ if args.sd_model is not None:
178
+ print(f"loading SD model: {args.sd_model}")
179
+
180
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
181
+
182
+ merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
183
+
184
+ print(f"saving SD model to: {args.save_to}")
185
+ model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
186
+ args.sd_model, 0, 0, save_dtype, vae)
187
+ else:
188
+ state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
189
+
190
+ print(f"saving model to: {args.save_to}")
191
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype)
192
+
193
+
194
+ if __name__ == '__main__':
195
+ parser = argparse.ArgumentParser()
196
+ parser.add_argument("--v2", action='store_true',
197
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
198
+ parser.add_argument("--save_precision", type=str, default=None,
199
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, same as merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
200
+ parser.add_argument("--precision", type=str, default="float",
201
+ choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
202
+ parser.add_argument("--sd_model", type=str, default=None,
203
+ help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
204
+ parser.add_argument("--save_to", type=str, default=None,
205
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
206
+ parser.add_argument("--models", type=str, nargs='*',
207
+ help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
208
+ parser.add_argument("--ratios", type=float, nargs='*',
209
+ help="ratios for each model / それぞれのLoRAモデルの比率")
210
+
211
+ args = parser.parse_args()
212
+ merge(args)
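merge_to_sd_model above folds each LoRA pair into the base weight as W <- W + ratio * (alpha/dim) * up @ down (1x1 convolutions are squeezed to 2-D first). A minimal sketch of the linear case:

import torch

def merge_pair(weight, down, up, alpha, ratio):
    dim = down.shape[0]
    return weight + ratio * (alpha / dim) * (up @ down)

w = torch.randn(320, 768)
w_merged = merge_pair(w, torch.randn(4, 768), torch.randn(320, 4), alpha=4.0, ratio=0.7)
print(torch.equal(w, w_merged))  # False: the LoRA delta is now baked into the weight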
networks/merge_lora_old.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ import library.model_util as model_util
8
+ import lora
9
+
10
+
11
+ def load_state_dict(file_name, dtype):
12
+ if os.path.splitext(file_name)[1] == '.safetensors':
13
+ sd = load_file(file_name)
14
+ else:
15
+ sd = torch.load(file_name, map_location='cpu')
16
+ for key in list(sd.keys()):
17
+ if type(sd[key]) == torch.Tensor:
18
+ sd[key] = sd[key].to(dtype)
19
+ return sd
20
+
21
+
22
+ def save_to_file(file_name, model, state_dict, dtype):
23
+ if dtype is not None:
24
+ for key in list(state_dict.keys()):
25
+ if type(state_dict[key]) == torch.Tensor:
26
+ state_dict[key] = state_dict[key].to(dtype)
27
+
28
+ if os.path.splitext(file_name)[1] == '.safetensors':
29
+ save_file(model, file_name)
30
+ else:
31
+ torch.save(model, file_name)
32
+
33
+
34
+ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
35
+ text_encoder.to(merge_dtype)
36
+ unet.to(merge_dtype)
37
+
38
+ # create module map
39
+ name_to_module = {}
40
+ for i, root_module in enumerate([text_encoder, unet]):
41
+ if i == 0:
42
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
43
+ target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
44
+ else:
45
+ prefix = lora.LoRANetwork.LORA_PREFIX_UNET
46
+ target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
47
+
48
+ for name, module in root_module.named_modules():
49
+ if module.__class__.__name__ in target_replace_modules:
50
+ for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
+ lora_name = prefix + '.' + name + '.' + child_name
53
+ lora_name = lora_name.replace('.', '_')
54
+ name_to_module[lora_name] = child_module
55
+
56
+ for model, ratio in zip(models, ratios):
57
+ print(f"loading: {model}")
58
+ lora_sd = load_state_dict(model, merge_dtype)
59
+
60
+ print(f"merging...")
61
+ for key in lora_sd.keys():
62
+ if "lora_down" in key:
63
+ up_key = key.replace("lora_down", "lora_up")
64
+ alpha_key = key[:key.index("lora_down")] + 'alpha'
65
+
66
+ # find original module for this lora
67
+ module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
68
+ if module_name not in name_to_module:
69
+ print(f"no module found for LoRA weight: {key}")
70
+ continue
71
+ module = name_to_module[module_name]
72
+ # print(f"apply {key} to {module}")
73
+
74
+ down_weight = lora_sd[key]
75
+ up_weight = lora_sd[up_key]
76
+
77
+ dim = down_weight.size()[0]
78
+ alpha = lora_sd.get(alpha_key, dim)
79
+ scale = alpha / dim
80
+
81
+ # W <- W + U * D
82
+ weight = module.weight
83
+ if len(weight.size()) == 2:
84
+ # linear
85
+ weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ else:
87
+ # conv2d
88
+ weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
89
+
90
+ module.weight = torch.nn.Parameter(weight)
91
+
92
+
93
+ def merge_lora_models(models, ratios, merge_dtype):
94
+ merged_sd = {}
95
+
96
+ alpha = None
97
+ dim = None
98
+ for model, ratio in zip(models, ratios):
99
+ print(f"loading: {model}")
100
+ lora_sd = load_state_dict(model, merge_dtype)
101
+
102
+ print(f"merging...")
103
+ for key in lora_sd.keys():
104
+ if 'alpha' in key:
105
+ if key in merged_sd:
106
+ assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
107
+ else:
108
+ alpha = lora_sd[key].detach().numpy()
109
+ merged_sd[key] = lora_sd[key]
110
+ else:
111
+ if key in merged_sd:
112
+ assert merged_sd[key].size() == lora_sd[key].size(
113
+ ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
114
+ merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
115
+ else:
116
+ if "lora_down" in key:
117
+ dim = lora_sd[key].size()[0]
118
+ merged_sd[key] = lora_sd[key] * ratio
119
+
120
+ print(f"dim (rank): {dim}, alpha: {alpha}")
121
+ if alpha is None:
122
+ alpha = dim
123
+
124
+ return merged_sd, dim, alpha
125
+
126
+
127
+ def merge(args):
128
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
129
+
130
+ def str_to_dtype(p):
131
+ if p == 'float':
132
+ return torch.float
133
+ if p == 'fp16':
134
+ return torch.float16
135
+ if p == 'bf16':
136
+ return torch.bfloat16
137
+ return None
138
+
139
+ merge_dtype = str_to_dtype(args.precision)
140
+ save_dtype = str_to_dtype(args.save_precision)
141
+ if save_dtype is None:
142
+ save_dtype = merge_dtype
143
+
144
+ if args.sd_model is not None:
145
+ print(f"loading SD model: {args.sd_model}")
146
+
147
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
148
+
149
+ merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
150
+
151
+ print(f"saving SD model to: {args.save_to}")
152
+ model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
153
+ args.sd_model, 0, 0, save_dtype, vae)
154
+ else:
155
+ state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
156
+
157
+ print(f"saving model to: {args.save_to}")
158
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype)
159
+
160
+
161
+ if __name__ == '__main__':
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument("--v2", action='store_true',
164
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
165
+ parser.add_argument("--save_precision", type=str, default=None,
166
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, same as merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
167
+ parser.add_argument("--precision", type=str, default="float",
168
+ choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
169
+ parser.add_argument("--sd_model", type=str, default=None,
170
+ help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
171
+ parser.add_argument("--save_to", type=str, default=None,
172
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
173
+ parser.add_argument("--models", type=str, nargs='*',
174
+ help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
175
+ parser.add_argument("--ratios", type=float, nargs='*',
176
+ help="ratios for each model / それぞれのLoRAモデルの比率")
177
+
178
+ args = parser.parse_args()
179
+ merge(args)
networks/resize_lora.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
+ # This code is based on extract_lora_from_models.py, which in turn is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo and kohya
4
+
5
+ import argparse
6
+ import os
7
+ import torch
8
+ from safetensors.torch import load_file, save_file, safe_open
9
+ from tqdm import tqdm
10
+ from library import train_util, model_util
11
+
12
+
13
+ def load_state_dict(file_name, dtype):
14
+ if model_util.is_safetensors(file_name):
15
+ sd = load_file(file_name)
16
+ with safe_open(file_name, framework="pt") as f:
17
+ metadata = f.metadata()
18
+ else:
19
+ sd = torch.load(file_name, map_location='cpu')
20
+ metadata = None
21
+
22
+ for key in list(sd.keys()):
23
+ if type(sd[key]) == torch.Tensor:
24
+ sd[key] = sd[key].to(dtype)
25
+
26
+ return sd, metadata
27
+
28
+
29
+ def save_to_file(file_name, model, state_dict, dtype, metadata):
30
+ if dtype is not None:
31
+ for key in list(state_dict.keys()):
32
+ if type(state_dict[key]) == torch.Tensor:
33
+ state_dict[key] = state_dict[key].to(dtype)
34
+
35
+ if model_util.is_safetensors(file_name):
36
+ save_file(model, file_name, metadata)
37
+ else:
38
+ torch.save(model, file_name)
39
+
40
+
41
+ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
42
+ network_alpha = None
43
+ network_dim = None
44
+ verbose_str = "\n"
45
+
46
+ CLAMP_QUANTILE = 0.99
47
+
48
+ # Extract loaded lora dim and alpha
49
+ for key, value in lora_sd.items():
50
+ if network_alpha is None and 'alpha' in key:
51
+ network_alpha = value
52
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
53
+ network_dim = value.size()[0]
54
+ if network_alpha is not None and network_dim is not None:
55
+ break
56
+ if network_alpha is None:
57
+ network_alpha = network_dim
58
+
59
+ scale = network_alpha/network_dim
60
+ new_alpha = float(scale*new_rank) # calculate new alpha from scale
61
+
62
+ print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
63
+
64
+ lora_down_weight = None
65
+ lora_up_weight = None
66
+
67
+ o_lora_sd = lora_sd.copy()
68
+ block_down_name = None
69
+ block_up_name = None
70
+
71
+ print("resizing lora...")
72
+ with torch.no_grad():
73
+ for key, value in tqdm(lora_sd.items()):
74
+ if 'lora_down' in key:
75
+ block_down_name = key.split(".")[0]
76
+ lora_down_weight = value
77
+ if 'lora_up' in key:
78
+ block_up_name = key.split(".")[0]
79
+ lora_up_weight = value
80
+
81
+ weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
82
+
83
+ if (block_down_name == block_up_name) and weights_loaded:
84
+
85
+ conv2d = (len(lora_down_weight.size()) == 4)
86
+
87
+ if conv2d:
88
+ lora_down_weight = lora_down_weight.squeeze()
89
+ lora_up_weight = lora_up_weight.squeeze()
90
+
91
+ if device:
92
+ org_device = lora_up_weight.device
93
+ lora_up_weight = lora_up_weight.to(device)
94
+ lora_down_weight = lora_down_weight.to(device)
95
+
96
+ full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
97
+
98
+ U, S, Vh = torch.linalg.svd(full_weight_matrix)
99
+
100
+ if verbose:
101
+ s_sum = torch.sum(torch.abs(S))
102
+ s_rank = torch.sum(torch.abs(S[:new_rank]))
103
+ verbose_str+=f"{block_down_name:76} | "
104
+ verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
105
+
106
+ U = U[:, :new_rank]
107
+ S = S[:new_rank]
108
+ U = U @ torch.diag(S)
109
+
110
+ Vh = Vh[:new_rank, :]
111
+
112
+ dist = torch.cat([U.flatten(), Vh.flatten()])
113
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
114
+ low_val = -hi_val
115
+
116
+ U = U.clamp(low_val, hi_val)
117
+ Vh = Vh.clamp(low_val, hi_val)
118
+
119
+ if conv2d:
120
+ U = U.unsqueeze(2).unsqueeze(3)
121
+ Vh = Vh.unsqueeze(2).unsqueeze(3)
122
+
123
+ if device:
124
+ U = U.to(org_device)
125
+ Vh = Vh.to(org_device)
126
+
127
+ o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
128
+ o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
129
+ o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
130
+
131
+ block_down_name = None
132
+ block_up_name = None
133
+ lora_down_weight = None
134
+ lora_up_weight = None
135
+ weights_loaded = False
136
+
137
+ if verbose:
138
+ print(verbose_str)
139
+ print("resizing complete")
140
+ return o_lora_sd, network_dim, new_alpha
141
+
142
+
143
+ def resize(args):
144
+
145
+ def str_to_dtype(p):
146
+ if p == 'float':
147
+ return torch.float
148
+ if p == 'fp16':
149
+ return torch.float16
150
+ if p == 'bf16':
151
+ return torch.bfloat16
152
+ return None
153
+
154
+ merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
155
+ save_dtype = str_to_dtype(args.save_precision)
156
+ if save_dtype is None:
157
+ save_dtype = merge_dtype
158
+
159
+ print("loading Model...")
160
+ lora_sd, metadata = load_state_dict(args.model, merge_dtype)
161
+
162
+ print("resizing rank...")
163
+ state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
164
+
165
+ # update metadata
166
+ if metadata is None:
167
+ metadata = {}
168
+
169
+ comment = metadata.get("ss_training_comment", "")
170
+ metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
171
+ metadata["ss_network_dim"] = str(args.new_rank)
172
+ metadata["ss_network_alpha"] = str(new_alpha)
173
+
174
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
175
+ metadata["sshs_model_hash"] = model_hash
176
+ metadata["sshs_legacy_hash"] = legacy_hash
177
+
178
+ print(f"saving model to: {args.save_to}")
179
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
180
+
181
+
182
+ if __name__ == '__main__':
183
+ parser = argparse.ArgumentParser()
184
+
185
+ parser.add_argument("--save_precision", type=str, default=None,
186
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
187
+ parser.add_argument("--new_rank", type=int, default=4,
188
+ help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
189
+ parser.add_argument("--save_to", type=str, default=None,
190
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
191
+ parser.add_argument("--model", type=str, default=None,
192
+ help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
193
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
194
+ parser.add_argument("--verbose", action="store_true",
195
+ help="Display verbose resizing information / rank変更時の詳細情報を出力する")
196
+
197
+ args = parser.parse_args()
198
+ resize(args)
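The new alpha is chosen so the effective scale alpha/rank is unchanged after resizing, which keeps the LoRA's overall strength the same. A short check of that invariant:

old_dim, old_alpha, new_rank = 128, 64.0, 16
scale = old_alpha / old_dim        # effective scale of the original LoRA (0.5 here)
new_alpha = scale * new_rank       # 8.0, as computed by resize_lora_model
assert new_alpha / new_rank == old_alpha / old_dim
print(new_alpha)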
networks/svd_merge_lora.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ from tqdm import tqdm
8
+ import library.model_util as model_util
9
+ import lora
10
+
11
+
12
+ CLAMP_QUANTILE = 0.99
13
+
14
+
15
+ def load_state_dict(file_name, dtype):
16
+ if os.path.splitext(file_name)[1] == '.safetensors':
17
+ sd = load_file(file_name)
18
+ else:
19
+ sd = torch.load(file_name, map_location='cpu')
20
+ for key in list(sd.keys()):
21
+ if type(sd[key]) == torch.Tensor:
22
+ sd[key] = sd[key].to(dtype)
23
+ return sd
24
+
25
+
26
+ def save_to_file(file_name, model, state_dict, dtype):
27
+ if dtype is not None:
28
+ for key in list(state_dict.keys()):
29
+ if type(state_dict[key]) == torch.Tensor:
30
+ state_dict[key] = state_dict[key].to(dtype)
31
+
32
+ if os.path.splitext(file_name)[1] == '.safetensors':
33
+ save_file(model, file_name)
34
+ else:
35
+ torch.save(model, file_name)
36
+
37
+
38
+ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
39
+ merged_sd = {}
40
+ for model, ratio in zip(models, ratios):
41
+ print(f"loading: {model}")
42
+ lora_sd = load_state_dict(model, merge_dtype)
43
+
44
+ # merge
45
+ print(f"merging...")
46
+ for key in tqdm(list(lora_sd.keys())):
47
+ if 'lora_down' not in key:
48
+ continue
49
+
50
+ lora_module_name = key[:key.rfind(".lora_down")]
51
+
52
+ down_weight = lora_sd[key]
53
+ network_dim = down_weight.size()[0]
54
+
55
+ up_weight = lora_sd[lora_module_name + '.lora_up.weight']
56
+ alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
57
+
58
+ in_dim = down_weight.size()[1]
59
+ out_dim = up_weight.size()[0]
60
+ conv2d = len(down_weight.size()) == 4
61
+ print(lora_module_name, network_dim, alpha, in_dim, out_dim)
62
+
63
+ # make original weight if not exist
64
+ if lora_module_name not in merged_sd:
65
+ weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
66
+ if device:
67
+ weight = weight.to(device)
68
+ else:
69
+ weight = merged_sd[lora_module_name]
70
+
71
+ # merge to weight
72
+ if device:
73
+ up_weight = up_weight.to(device)
74
+ down_weight = down_weight.to(device)
75
+
76
+ # W <- W + U * D
77
+ scale = (alpha / network_dim)
78
+ if not conv2d: # linear
79
+ weight = weight + ratio * (up_weight @ down_weight) * scale
80
+ else:
81
+ weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
82
+ ).unsqueeze(2).unsqueeze(3) * scale
83
+
84
+ merged_sd[lora_module_name] = weight
85
+
86
+ # extract from merged weights
87
+ print("extract new lora...")
88
+ merged_lora_sd = {}
89
+ with torch.no_grad():
90
+ for lora_module_name, mat in tqdm(list(merged_sd.items())):
91
+ conv2d = (len(mat.size()) == 4)
92
+ if conv2d:
93
+ mat = mat.squeeze()
94
+
95
+ U, S, Vh = torch.linalg.svd(mat)
96
+
97
+ U = U[:, :new_rank]
98
+ S = S[:new_rank]
99
+ U = U @ torch.diag(S)
100
+
101
+ Vh = Vh[:new_rank, :]
102
+
103
+ dist = torch.cat([U.flatten(), Vh.flatten()])
104
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
105
+ low_val = -hi_val
106
+
107
+ U = U.clamp(low_val, hi_val)
108
+ Vh = Vh.clamp(low_val, hi_val)
109
+
110
+ up_weight = U
111
+ down_weight = Vh
112
+
113
+ if conv2d:
114
+ up_weight = up_weight.unsqueeze(2).unsqueeze(3)
115
+ down_weight = down_weight.unsqueeze(2).unsqueeze(3)
116
+
117
+ merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
118
+ merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
119
+ merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
120
+
121
+ return merged_lora_sd
122
+
123
+
124
+ def merge(args):
125
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
126
+
127
+ def str_to_dtype(p):
128
+ if p == 'float':
129
+ return torch.float
130
+ if p == 'fp16':
131
+ return torch.float16
132
+ if p == 'bf16':
133
+ return torch.bfloat16
134
+ return None
135
+
136
+ merge_dtype = str_to_dtype(args.precision)
137
+ save_dtype = str_to_dtype(args.save_precision)
138
+ if save_dtype is None:
139
+ save_dtype = merge_dtype
140
+
141
+ state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
142
+
143
+ print(f"saving model to: {args.save_to}")
144
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype)
145
+
146
+
147
+ if __name__ == '__main__':
148
+ parser = argparse.ArgumentParser()
149
+ parser.add_argument("--save_precision", type=str, default=None,
150
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, same as merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
151
+ parser.add_argument("--precision", type=str, default="float",
152
+ choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
153
+ parser.add_argument("--save_to", type=str, default=None,
154
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
155
+ parser.add_argument("--models", type=str, nargs='*',
156
+ help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
157
+ parser.add_argument("--ratios", type=float, nargs='*',
158
+ help="ratios for each model / それぞれのLoRAモデルの比率")
159
+ parser.add_argument("--new_rank", type=int, default=4,
160
+ help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
161
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
+
163
+ args = parser.parse_args()
164
+ merge(args)
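Unlike merge_lora.py, this script first accumulates every LoRA's full-rank delta per module and then re-extracts a single pair at new_rank; the stored alpha equals new_rank, so the new pair is applied at scale 1. A minimal sketch of that accumulate-then-SVD idea:

import torch

# two rank-4 deltas for the same 320x768 module, blended with ratios 0.6 and 0.4
delta = 0.6 * (torch.randn(320, 4) @ torch.randn(4, 768)) \
      + 0.4 * (torch.randn(320, 4) @ torch.randn(4, 768))

U, S, Vh = torch.linalg.svd(delta)
new_rank = 8
up, down = U[:, :new_rank] @ torch.diag(S[:new_rank]), Vh[:new_rank, :]
print(torch.dist(delta, up @ down))  # ~0: the blend has rank <= 8, so rank 8 recovers it exactly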
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ transformers==4.26.0
3
+ ftfy==6.1.1
4
+ albumentations==1.3.0
5
+ opencv-python==4.7.0.68
6
+ einops==0.6.0
7
+ diffusers[torch]==0.10.2
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.35.0
10
+ tensorboard==2.10.1
11
+ safetensors==0.2.6
12
+ gradio==3.16.2
13
+ altair==4.2.2
14
+ easygui==0.98.3
15
+ # for BLIP captioning
16
+ requests==2.28.2
17
+ timm==0.6.12
18
+ fairscale==0.4.13
19
+ # for WD14 captioning
20
+ # tensorflow<2.11
21
+ tensorflow==2.10.1
22
+ huggingface-hub==0.12.0
23
+ # for kohya_ss library
24
+ #locon.locon_kohya
25
+ .
requirements_startup.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ transformers==4.26.0
3
+ ftfy==6.1.1
4
+ albumentations==1.3.0
5
+ opencv-python==4.7.0.68
6
+ einops==0.6.0
7
+ diffusers[torch]==0.10.2
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.35.0
10
+ tensorboard==2.10.1
11
+ safetensors==0.2.6
12
+ gradio==3.18.0
13
+ altair==4.2.2
14
+ easygui==0.98.3
15
+ # for BLIP captioning
16
+ requests==2.28.2
17
+ timm==0.4.12
18
+ fairscale==0.4.4
19
+ # for WD14 captioning
20
+ tensorflow==2.10.1
21
+ huggingface-hub==0.12.0
22
+ # for kohya_ss library
23
+ .
setup.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(name = "library", packages = find_packages())
tools/convert_diffusers20_original_sd.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ import library.model_util as model_util
9
+
10
+
11
+ def convert(args):
12
+ # validate the arguments
13
+ load_dtype = torch.float16 if args.fp16 else None
14
+
15
+ save_dtype = None
16
+ if args.fp16:
17
+ save_dtype = torch.float16
18
+ elif args.bf16:
19
+ save_dtype = torch.bfloat16
20
+ elif args.float:
21
+ save_dtype = torch.float
22
+
23
+ is_load_ckpt = os.path.isfile(args.model_to_load)
24
+ is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
25
+
26
+ assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
27
+ assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
28
+
29
+ # load the model
30
+ msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
31
+ print(f"loading {msg}: {args.model_to_load}")
32
+
33
+ if is_load_ckpt:
34
+ v2_model = args.v2
35
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
36
+ else:
37
+ pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
38
+ text_encoder = pipe.text_encoder
39
+ vae = pipe.vae
40
+ unet = pipe.unet
41
+
42
+ if args.v1 == args.v2:
43
+ # auto-detect the model version
44
+ v2_model = unet.config.cross_attention_dim == 1024
45
+ print("checking model version: model is " + ('v2' if v2_model else 'v1'))
46
+ else:
47
+ v2_model = not args.v1
48
+
49
+ # convert and save
50
+ msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
51
+ print(f"converting and saving as {msg}: {args.model_to_save}")
52
+
53
+ if is_save_ckpt:
54
+ original_model = args.model_to_load if is_load_ckpt else None
55
+ key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
56
+ original_model, args.epoch, args.global_step, save_dtype, vae)
57
+ print(f"model saved. total converted state_dict keys: {key_count}")
58
+ else:
59
+ print(f"copy scheduler/tokenizer config from: {args.reference_model}")
60
+ model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
61
+ print(f"model saved.")
62
+
63
+
64
+ if __name__ == '__main__':
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--v1", action='store_true',
67
+ help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
68
+ parser.add_argument("--v2", action='store_true',
69
+ help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
70
+ parser.add_argument("--fp16", action='store_true',
71
+ help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
72
+ parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
73
+ parser.add_argument("--float", action='store_true',
74
+ help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
75
+ parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
76
+ parser.add_argument("--global_step", type=int, default=0,
77
+ help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
78
+ parser.add_argument("--reference_model", type=str, default=None,
79
+ help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
80
+ parser.add_argument("--use_safetensors", action='store_true',
81
+ help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
82
+
83
+ parser.add_argument("model_to_load", type=str, default=None,
84
+ help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
85
+ parser.add_argument("model_to_save", type=str, default=None,
86
+ help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
87
+
88
+ args = parser.parse_args()
89
+ convert(args)
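For reference, a minimal sketch of driving the converter above from Python rather than the command line. It assumes the script is importable (the module name and file paths below are placeholders) and simply fills an argparse.Namespace with the same fields the parser above defines:

import argparse
import convert_diffusers20_original_sd as conv  # hypothetical import name for the script above

# Field names mirror the argparse options defined above; values are illustrative only.
args = argparse.Namespace(
    v1=False, v2=True,                     # v1/v2 must be set when loading a checkpoint
    fp16=False, bf16=False, float=True,    # save weights as float32
    epoch=0, global_step=0,
    reference_model=None,                  # only needed when saving in Diffusers format
    use_safetensors=False,
    model_to_load="input-v2.ckpt",         # placeholder: checkpoint file or Diffusers directory
    model_to_save="output-v2-float.ckpt",  # extension present -> saved as a checkpoint
)
conv.convert(args)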
tools/detect_face_rotate.py ADDED
@@ -0,0 +1,239 @@
+ # This script is licensed under Apache License 2.0, the same as train_dreambooth.py
+ # (c) 2022 Kohya S. @kohya_ss
+
+ # Detect faces in (landscape) images, rotate the image so the face is upright, and crop a square centered on the face
+
+ # v2: extract max face if multiple faces are found
+ # v3: add crop_ratio option
+ # v4: add multiple faces extraction and min/max size
+
+ import argparse
+ import math
+ import cv2
+ import glob
+ import os
+ from anime_face_detector import create_detector
+ from tqdm import tqdm
+ import numpy as np
+
+ KP_REYE = 11
+ KP_LEYE = 19
+
+ SCORE_THRES = 0.90
+
+
+ def detect_faces(detector, image, min_size):
+   preds = detector(image)  # bgr
+   # print(len(preds))
+
+   faces = []
+   for pred in preds:
+     bb = pred['bbox']
+     score = bb[-1]
+     if score < SCORE_THRES:
+       continue
+
+     left, top, right, bottom = bb[:4]
+     cx = int((left + right) / 2)
+     cy = int((top + bottom) / 2)
+     fw = int(right - left)
+     fh = int(bottom - top)
+
+     lex, ley = pred['keypoints'][KP_LEYE, 0:2]
+     rex, rey = pred['keypoints'][KP_REYE, 0:2]
+     angle = math.atan2(ley - rey, lex - rex)
+     angle = angle / math.pi * 180
+
+     faces.append((cx, cy, fw, fh, angle))
+
+   faces.sort(key=lambda x: max(x[2], x[3]), reverse=True)  # largest faces first
+   return faces
+
+
+ def rotate_image(image, angle, cx, cy):
+   h, w = image.shape[0:2]
+   rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
+
+   # # enlarge the image a little to make room for the rotation -> disabled for now
+   # nh = max(h, int(w * math.sin(angle)))
+   # nw = max(w, int(h * math.sin(angle)))
+   # if nh > h or nw > w:
+   #   pad_y = nh - h
+   #   pad_t = pad_y // 2
+   #   pad_x = nw - w
+   #   pad_l = pad_x // 2
+   #   m = np.array([[0, 0, pad_l],
+   #                 [0, 0, pad_t]])
+   #   rot_mat = rot_mat + m
+   #   h, w = nh, nw
+   #   cx += pad_l
+   #   cy += pad_t
+
+   result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
+   return result, cx, cy
+
+
+ def process(args):
+   assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't both be specified / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
+   assert args.crop_ratio is None or args.resize_face_size is None, f"resize_face_size can't be specified with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"
+
+   # load the anime face detection model
+   print("loading face detector.")
+   detector = create_detector('yolov3')
+
+   # parse the crop arguments
+   if args.crop_size is None:
+     crop_width = crop_height = None
+   else:
+     tokens = args.crop_size.split(',')
+     assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
+     crop_width, crop_height = [int(t) for t in tokens]
+
+   if args.crop_ratio is None:
+     crop_h_ratio = crop_v_ratio = None
+   else:
+     tokens = args.crop_ratio.split(',')
+     assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
+     crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
+
+   # process the images
+   print("processing.")
+   output_extension = ".png"
+
+   os.makedirs(args.dst_dir, exist_ok=True)
+   paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
+       glob.glob(os.path.join(args.src_dir, "*.webp"))
+   for path in tqdm(paths):
+     basename = os.path.splitext(os.path.basename(path))[0]
+
+     # image = cv2.imread(path)  # fails on non-ASCII (e.g. Japanese) file names
+     image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
+     if len(image.shape) == 2:
+       image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+     if image.shape[2] == 4:
+       print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
+       image = image[:, :, :3].copy()  # without copy(), the alpha information apparently stays attached internally
+
+     h, w = image.shape[:2]
+
+     faces = detect_faces(detector, image, args.multiple_faces)
+     for i, face in enumerate(faces):
+       cx, cy, fw, fh, angle = face
+       face_size = max(fw, fh)
+       if args.min_size is not None and face_size < args.min_size:
+         continue
+       if args.max_size is not None and face_size >= args.max_size:
+         continue
+       face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
+
+       # rotate if requested
+       face_img = image
+       if args.rotate:
+         face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
+
+       # crop around the face if requested
+       if crop_width is not None or crop_h_ratio is not None:
+         cur_crop_width, cur_crop_height = crop_width, crop_height
+         if crop_h_ratio is not None:
+           cur_crop_width = int(face_size * crop_h_ratio + .5)
+           cur_crop_height = int(face_size * crop_v_ratio + .5)
+
+         # resize if necessary
+         scale = 1.0
+         if args.resize_face_size is not None:
+           # resize based on the face size
+           scale = args.resize_face_size / face_size
+           if scale < cur_crop_width / w:
+             print(
+                 f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+             scale = cur_crop_width / w
+           if scale < cur_crop_height / h:
+             print(
+                 f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+             scale = cur_crop_height / h
+         elif crop_h_ratio is not None:
+           # no resizing when a ratio is specified
+           pass
+         else:
+           # a fixed crop size is specified
+           if w < cur_crop_width:
+             print(f"image width too small / 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
+             scale = cur_crop_width / w
+           if h < cur_crop_height:
+             print(f"image height too small / 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
+             scale = cur_crop_height / h
+           if args.resize_fit:
+             scale = max(cur_crop_width / w, cur_crop_height / h)
+
+         if scale != 1.0:
+           w = int(w * scale + .5)
+           h = int(h * scale + .5)
+           face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
+           cx = int(cx * scale + .5)
+           cy = int(cy * scale + .5)
+           fw = int(fw * scale + .5)
+           fh = int(fh * scale + .5)
+
+         cur_crop_width = min(cur_crop_width, face_img.shape[1])
+         cur_crop_height = min(cur_crop_height, face_img.shape[0])
+
+         x = cx - cur_crop_width // 2
+         cx = cur_crop_width // 2
+         if x < 0:
+           cx = cx + x
+           x = 0
+         elif x + cur_crop_width > w:
+           cx = cx + (x + cur_crop_width - w)
+           x = w - cur_crop_width
+         face_img = face_img[:, x:x+cur_crop_width]
+
+         y = cy - cur_crop_height // 2
+         cy = cur_crop_height // 2
+         if y < 0:
+           cy = cy + y
+           y = 0
+         elif y + cur_crop_height > h:
+           cy = cy + (y + cur_crop_height - h)
+           y = h - cur_crop_height
+         face_img = face_img[y:y + cur_crop_height]
+
+       # # debug
+       # print(path, cx, cy, angle)
+       # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
+       # cv2.imshow("image", crp)
+       # if cv2.waitKey() == 27:
+       #   break
+       # cv2.destroyAllWindows()
+
+       # debug
+       if args.debug:
+         cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
+
+       _, buf = cv2.imencode(output_extension, face_img)
+       with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
+         buf.tofile(f)
+
+
+ if __name__ == '__main__':
+   parser = argparse.ArgumentParser()
+   parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
+   parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
+   parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
+   parser.add_argument("--resize_fit", action="store_true",
+                       help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
+   parser.add_argument("--resize_face_size", type=int, default=None,
+                       help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
+   parser.add_argument("--crop_size", type=str, default=None,
+                       help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
+   parser.add_argument("--crop_ratio", type=str, default=None,
+                       help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
+   parser.add_argument("--min_size", type=int, default=None,
+                       help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
+   parser.add_argument("--max_size", type=int, default=None,
+                       help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
+   parser.add_argument("--multiple_faces", action="store_true",
+                       help="output each detected face / 複数の顔が見つかった場合、それぞれを切り出す")
+   parser.add_argument("--debug", action="store_true", help="draw a rectangle at the face position in the output image / 処理後画像の顔位置に矩形を描画します")
+   args = parser.parse_args()
+
+   process(args)
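The same pattern works for the face-cropping tool above. A hedged sketch, assuming the script is importable (the module name and directories are placeholders), that rotates each detected face upright and crops a 512x512 square around it:

import argparse
import detect_face_rotate as dfr  # hypothetical import name for the script above

# Field names mirror the argparse options defined above; values are illustrative only.
args = argparse.Namespace(
    src_dir="raw_images", dst_dir="face_crops",  # placeholder directories
    rotate=True,                 # rotate so the detected face is upright
    resize_fit=False, resize_face_size=None,
    crop_size="512,512",         # crop a 512x512 region centered on the face
    crop_ratio=None,
    min_size=128, max_size=None, # skip faces smaller than 128 px
    multiple_faces=True, debug=False,
)
dfr.process(args)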
tools/resize_images_to_resolution.py ADDED
@@ -0,0 +1,122 @@
+ import glob
+ import os
+ import cv2
+ import argparse
+ import shutil
+ import math
+ from PIL import Image
+ import numpy as np
+
+
+ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
+   # Split the max_resolution string by "," and strip any whitespace
+   max_resolutions = [res.strip() for res in max_resolution.split(',')]
+
+   # # Calculate max_pixels from max_resolution string
+   # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+   # Create destination folder if it does not exist
+   if not os.path.exists(dst_img_folder):
+     os.makedirs(dst_img_folder)
+
+   # Select interpolation method
+   if interpolation == 'lanczos4':
+     cv2_interpolation = cv2.INTER_LANCZOS4
+   elif interpolation == 'cubic':
+     cv2_interpolation = cv2.INTER_CUBIC
+   else:
+     cv2_interpolation = cv2.INTER_AREA
+
+   # Iterate through all files in src_img_folder
+   img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")  # copied from train_util.py
+   for filename in os.listdir(src_img_folder):
+     # Check whether the file is an image (png, jpg, webp etc.)
+     if not filename.endswith(img_exts):
+       # Copy non-image files (.txt, .caption etc.) to the destination folder as-is
+       shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
+       continue
+
+     # Load image
+     # img = cv2.imread(os.path.join(src_img_folder, filename))
+     image = Image.open(os.path.join(src_img_folder, filename))
+     if not image.mode == "RGB":
+       image = image.convert("RGB")
+     img = np.array(image, np.uint8)
+
+     base, _ = os.path.splitext(filename)
+     for max_resolution in max_resolutions:
+       # Calculate max_pixels from max_resolution string
+       max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+       # Calculate current number of pixels
+       current_pixels = img.shape[0] * img.shape[1]
+
+       # Check if the image needs resizing
+       if current_pixels > max_pixels:
+         # Calculate scaling factor
+         scale_factor = max_pixels / current_pixels
+
+         # Calculate new dimensions
+         new_height = int(img.shape[0] * math.sqrt(scale_factor))
+         new_width = int(img.shape[1] * math.sqrt(scale_factor))
+
+         # Resize image
+         img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
+       else:
+         new_height, new_width = img.shape[0:2]
+
+       # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
+       new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
+       new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
+
+       # Center crop the image to the calculated dimensions
+       y = int((img.shape[0] - new_height) / 2)
+       x = int((img.shape[1] - new_width) / 2)
+       img = img[y:y + new_height, x:x + new_width]
+
+       # Build the output filename with a resolution suffix
+       new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
+
+       # Save resized image in dst_img_folder
+       # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
+       image = Image.fromarray(img)
+       image.save(os.path.join(dst_img_folder, new_filename), quality=100)
+
+       proc = "Resized" if current_pixels > max_pixels else "Saved"
+       print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
+
+     # If there are other files with the same basename, copy them with a resolution suffix
+     if copy_associated_files:
+       asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
+       for asoc_file in asoc_files:
+         ext = os.path.splitext(asoc_file)[1]
+         if ext in img_exts:
+           continue
+         for max_resolution in max_resolutions:
+           new_asoc_file = base + '+' + max_resolution + ext
+           print(f"Copy {asoc_file} as {new_asoc_file}")
+           shutil.copy(asoc_file, os.path.join(dst_img_folder, new_asoc_file))  # asoc_file already includes the folder from glob
+
+
+ def main():
+   parser = argparse.ArgumentParser(
+       description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
+   parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
+   parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
+   parser.add_argument('--max_resolution', type=str,
+                       help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
+   parser.add_argument('--divisible_by', type=int,
+                       help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
+   parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
+                       default='area', help='Interpolation method for resizing / リサイズ時の補間方法')
+   parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
+   parser.add_argument('--copy_associated_files', action='store_true',
+                       help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
+
+   args = parser.parse_args()
+   resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
+                 args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
+
+
+ if __name__ == '__main__':
+   main()
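Because resize_images() above is an ordinary function, it can also be called directly. A minimal sketch, assuming the script is importable (module name, folder names, and settings below are placeholders), that produces one output per target area with a resolution suffix in the filename (e.g. "img+512x512.png") and crops so both sides are divisible by the given value:

import resize_images_to_resolution as rir  # hypothetical import name for the script above

# Settings are illustrative only; each max_resolution entry caps the total pixel area.
rir.resize_images(
    "train_raw", "train_resized",        # placeholder source/destination folders
    max_resolution="768x768,512x512",
    divisible_by=64,
    interpolation="area",
    save_as_png=True,
    copy_associated_files=True,          # also copy .txt/.caption files with the suffix
)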