abc
commited on
Commit
·
26a0909
1
Parent(s):
3249d87
Upload 55 files
Browse files- .gitattributes +34 -1
- append_module.py +56 -378
- build/lib/library/__init__.py +0 -0
- build/lib/library/model_util.py +1180 -0
- build/lib/library/train_util.py +1796 -0
- fine_tune.py +45 -50
- gen_img_diffusers.py +55 -234
- library.egg-info/PKG-INFO +4 -0
- library.egg-info/SOURCES.txt +10 -0
- library.egg-info/dependency_links.txt +1 -0
- library.egg-info/top_level.txt +1 -0
- library/model_util.py +1 -5
- library/train_util.py +229 -853
- lora_train_popup.py +862 -0
- lycoris/kohya.py +0 -17
- lycoris/loha.py +1 -6
- lycoris/utils.py +2 -69
- networks/check_lora_weights.py +1 -1
- networks/extract_lora_from_models.py +25 -44
- networks/lora.py +30 -191
- networks/merge_lora.py +5 -11
- networks/resize_lora.py +50 -187
- networks/svd_merge_lora.py +18 -40
- requirements.txt +1 -2
- requirements_startup.txt +23 -0
- train_db.py +45 -47
- train_network.py +175 -248
- train_network_opt.py +373 -324
- train_textual_inversion.py +58 -72
.gitattributes
CHANGED
@@ -1 +1,34 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
append_module.py
CHANGED
@@ -2,19 +2,7 @@ import argparse
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
-
from typing import
|
6 |
-
Dict,
|
7 |
-
List,
|
8 |
-
NamedTuple,
|
9 |
-
Optional,
|
10 |
-
Sequence,
|
11 |
-
Tuple,
|
12 |
-
Union,
|
13 |
-
)
|
14 |
-
from dataclasses import (
|
15 |
-
asdict,
|
16 |
-
dataclass,
|
17 |
-
)
|
18 |
from accelerate import Accelerator
|
19 |
from torch.autograd.function import Function
|
20 |
import glob
|
@@ -40,7 +28,6 @@ import safetensors.torch
|
|
40 |
|
41 |
import library.model_util as model_util
|
42 |
import library.train_util as train_util
|
43 |
-
import library.config_util as config_util
|
44 |
|
45 |
#============================================================================================================
|
46 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
@@ -128,124 +115,6 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
|
|
128 |
return area_size_resos_list, area_size_list
|
129 |
|
130 |
#============================================================================================================
|
131 |
-
#config_util 内より
|
132 |
-
#============================================================================================================
|
133 |
-
@dataclass
|
134 |
-
class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
|
135 |
-
min_resolution: Optional[Tuple[int, int]] = None
|
136 |
-
area_step : int = 2
|
137 |
-
|
138 |
-
class ConfigSanitizer(config_util.ConfigSanitizer):
|
139 |
-
#@config_util.curry
|
140 |
-
@staticmethod
|
141 |
-
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
142 |
-
config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
|
143 |
-
return tuple(value)
|
144 |
-
|
145 |
-
#@config_util.curry
|
146 |
-
@staticmethod
|
147 |
-
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
148 |
-
config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
|
149 |
-
try:
|
150 |
-
config_util.Schema(klass)(value)
|
151 |
-
return (value, value)
|
152 |
-
except:
|
153 |
-
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
154 |
-
# datasets schema
|
155 |
-
DATASET_ASCENDABLE_SCHEMA = {
|
156 |
-
"batch_size": int,
|
157 |
-
"bucket_no_upscale": bool,
|
158 |
-
"bucket_reso_steps": int,
|
159 |
-
"enable_bucket": bool,
|
160 |
-
"max_bucket_reso": int,
|
161 |
-
"min_bucket_reso": int,
|
162 |
-
"resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
163 |
-
"min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
164 |
-
"area_step": int,
|
165 |
-
}
|
166 |
-
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
|
167 |
-
super().__init__(support_dreambooth, support_finetuning, support_dropout)
|
168 |
-
def _check(self):
|
169 |
-
print(self.db_dataset_schema)
|
170 |
-
|
171 |
-
class BlueprintGenerator(config_util.BlueprintGenerator):
|
172 |
-
def __init__(self, sanitizer: ConfigSanitizer):
|
173 |
-
config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
|
174 |
-
super().__init__(sanitizer)
|
175 |
-
|
176 |
-
def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
|
177 |
-
datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
|
178 |
-
|
179 |
-
for dataset_blueprint in dataset_group_blueprint.datasets:
|
180 |
-
if dataset_blueprint.is_dreambooth:
|
181 |
-
subset_klass = train_util.DreamBoothSubset
|
182 |
-
dataset_klass = DreamBoothDataset
|
183 |
-
else:
|
184 |
-
subset_klass = train_util.FineTuningSubset
|
185 |
-
dataset_klass = train_util.FineTuningDataset
|
186 |
-
|
187 |
-
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
188 |
-
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
189 |
-
datasets.append(dataset)
|
190 |
-
|
191 |
-
# print info
|
192 |
-
info = ""
|
193 |
-
for i, dataset in enumerate(datasets):
|
194 |
-
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
195 |
-
info += config_util.dedent(f"""\
|
196 |
-
[Dataset {i}]
|
197 |
-
batch_size: {dataset.batch_size}
|
198 |
-
resolution: {(dataset.width, dataset.height)}
|
199 |
-
enable_bucket: {dataset.enable_bucket}
|
200 |
-
""")
|
201 |
-
|
202 |
-
if dataset.enable_bucket:
|
203 |
-
info += config_util.indent(config_util.dedent(f"""\
|
204 |
-
min_bucket_reso: {dataset.min_bucket_reso}
|
205 |
-
max_bucket_reso: {dataset.max_bucket_reso}
|
206 |
-
bucket_reso_steps: {dataset.bucket_reso_steps}
|
207 |
-
bucket_no_upscale: {dataset.bucket_no_upscale}
|
208 |
-
\n"""), " ")
|
209 |
-
else:
|
210 |
-
info += "\n"
|
211 |
-
|
212 |
-
for j, subset in enumerate(dataset.subsets):
|
213 |
-
info += config_util.indent(config_util.dedent(f"""\
|
214 |
-
[Subset {j} of Dataset {i}]
|
215 |
-
image_dir: "{subset.image_dir}"
|
216 |
-
image_count: {subset.img_count}
|
217 |
-
num_repeats: {subset.num_repeats}
|
218 |
-
shuffle_caption: {subset.shuffle_caption}
|
219 |
-
keep_tokens: {subset.keep_tokens}
|
220 |
-
caption_dropout_rate: {subset.caption_dropout_rate}
|
221 |
-
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
222 |
-
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
223 |
-
color_aug: {subset.color_aug}
|
224 |
-
flip_aug: {subset.flip_aug}
|
225 |
-
face_crop_aug_range: {subset.face_crop_aug_range}
|
226 |
-
random_crop: {subset.random_crop}
|
227 |
-
"""), " ")
|
228 |
-
|
229 |
-
if is_dreambooth:
|
230 |
-
info += config_util.indent(config_util.dedent(f"""\
|
231 |
-
is_reg: {subset.is_reg}
|
232 |
-
class_tokens: {subset.class_tokens}
|
233 |
-
caption_extension: {subset.caption_extension}
|
234 |
-
\n"""), " ")
|
235 |
-
else:
|
236 |
-
info += config_util.indent(config_util.dedent(f"""\
|
237 |
-
metadata_file: {subset.metadata_file}
|
238 |
-
\n"""), " ")
|
239 |
-
|
240 |
-
print(info)
|
241 |
-
|
242 |
-
# make buckets first because it determines the length of dataset
|
243 |
-
for i, dataset in enumerate(datasets):
|
244 |
-
print(f"[Dataset {i}]")
|
245 |
-
dataset.make_buckets()
|
246 |
-
|
247 |
-
return train_util.DatasetGroup(datasets)
|
248 |
-
#============================================================================================================
|
249 |
#train_util 内より
|
250 |
#============================================================================================================
|
251 |
class BucketManager_append(train_util.BucketManager):
|
@@ -310,7 +179,7 @@ class BucketManager_append(train_util.BucketManager):
|
|
310 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
311 |
_min_error = 1000.
|
312 |
_min_id = bucket_size_id
|
313 |
-
for now_size_id in
|
314 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
315 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
316 |
ar_error = np.abs(ar_errors).min()
|
@@ -384,13 +253,13 @@ class BucketManager_append(train_util.BucketManager):
|
|
384 |
return reso, resized_size, ar_error
|
385 |
|
386 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
387 |
-
def __init__(self,
|
388 |
print("use append DreamBoothDataset")
|
389 |
self.min_resolution = min_resolution
|
390 |
self.area_step = area_step
|
391 |
-
super().__init__(
|
392 |
-
|
393 |
-
|
394 |
def make_buckets(self):
|
395 |
'''
|
396 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
@@ -483,50 +352,40 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
|
|
483 |
self.shuffle_buckets()
|
484 |
self._length = len(self.buckets_indices)
|
485 |
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
#============================================================================================================
|
500 |
#networks.lora
|
501 |
#============================================================================================================
|
502 |
-
|
503 |
-
def replace_prepare_optimizer_params(networks
|
504 |
-
def prepare_optimizer_params(self, text_encoder_lr, unet_lr,
|
505 |
-
|
506 |
def enumerate_params(loras, lora_name=None):
|
507 |
params = []
|
508 |
for lora in loras:
|
509 |
if lora_name is not None:
|
510 |
-
|
511 |
-
|
512 |
-
lora_names = [lora_name]
|
513 |
-
if "attentions" in lora_name:
|
514 |
-
lora_names.append(lora_name.replace("attentions", "resnets"))
|
515 |
-
elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
|
516 |
-
lora_names.append("lora_unet_up_blocks_0_upsamplers_")
|
517 |
-
elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
|
518 |
-
lora_names.append("lora_unet_up_blocks_1_upsamplers_")
|
519 |
-
elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
|
520 |
-
lora_names.append("lora_unet_up_blocks_2_upsamplers_")
|
521 |
-
|
522 |
-
for _name in lora_names:
|
523 |
-
if _name in lora.lora_name:
|
524 |
-
get_param_flag = True
|
525 |
-
break
|
526 |
-
else:
|
527 |
-
if lora_name in lora.lora_name:
|
528 |
-
get_param_flag = True
|
529 |
-
if get_param_flag: params.extend(lora.parameters())
|
530 |
else:
|
531 |
params.extend(lora.parameters())
|
532 |
return params
|
@@ -534,7 +393,6 @@ def replace_prepare_optimizer_params(networks, network_module):
|
|
534 |
self.requires_grad_(True)
|
535 |
all_params = []
|
536 |
ret_scheduler_lr = []
|
537 |
-
used_names = []
|
538 |
|
539 |
if loranames is not None:
|
540 |
textencoder_names = [None]
|
@@ -547,181 +405,37 @@ def replace_prepare_optimizer_params(networks, network_module):
|
|
547 |
if self.text_encoder_loras:
|
548 |
for textencoder_name in textencoder_names:
|
549 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
550 |
-
used_names.append(textencoder_name)
|
551 |
if text_encoder_lr is not None:
|
552 |
param_data['lr'] = text_encoder_lr
|
553 |
-
|
554 |
-
|
555 |
-
param_data['lr'] = lr_dic[textencoder_name]
|
556 |
-
print(f"{textencoder_name} lr: {param_data['lr']}")
|
557 |
-
|
558 |
-
if block_args_dic is not None:
|
559 |
-
if "lora_te_" in block_args_dic:
|
560 |
-
for pname, value in block_args_dic["lora_te_"].items():
|
561 |
-
param_data[pname] = value
|
562 |
-
if textencoder_name in block_args_dic:
|
563 |
-
for pname, value in block_args_dic[textencoder_name].items():
|
564 |
-
param_data[pname] = value
|
565 |
-
|
566 |
-
if text_encoder_lr is not None:
|
567 |
-
ret_scheduler_lr.append(text_encoder_lr)
|
568 |
-
else:
|
569 |
-
ret_scheduler_lr.append(0.)
|
570 |
-
if lr_dic is not None:
|
571 |
-
if textencoder_name in lr_dic:
|
572 |
-
ret_scheduler_lr[-1] = lr_dic[textencoder_name]
|
573 |
all_params.append(param_data)
|
574 |
|
575 |
if self.unet_loras:
|
576 |
for unet_name in unet_names:
|
577 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
578 |
-
if len(param_data["params"])==0: continue
|
579 |
-
used_names.append(unet_name)
|
580 |
if unet_lr is not None:
|
581 |
param_data['lr'] = unet_lr
|
582 |
-
|
583 |
-
|
584 |
-
param_data['lr'] = lr_dic[unet_name]
|
585 |
-
print(f"{unet_name} lr: {param_data['lr']}")
|
586 |
-
|
587 |
-
if block_args_dic is not None:
|
588 |
-
if "lora_unet_" in block_args_dic:
|
589 |
-
for pname, value in block_args_dic["lora_unet_"].items():
|
590 |
-
param_data[pname] = value
|
591 |
-
if unet_name in block_args_dic:
|
592 |
-
for pname, value in block_args_dic[unet_name].items():
|
593 |
-
param_data[pname] = value
|
594 |
-
|
595 |
-
if unet_lr is not None:
|
596 |
-
ret_scheduler_lr.append(unet_lr)
|
597 |
-
else:
|
598 |
-
ret_scheduler_lr.append(0.)
|
599 |
-
if lr_dic is not None:
|
600 |
-
if unet_name in lr_dic:
|
601 |
-
ret_scheduler_lr[-1] = lr_dic[unet_name]
|
602 |
all_params.append(param_data)
|
603 |
|
604 |
-
return all_params,
|
605 |
-
|
606 |
-
|
607 |
-
except:
|
608 |
-
print("cant't replace prepare_optimizer_params")
|
609 |
|
610 |
#============================================================================================================
|
611 |
#新規追加
|
612 |
#============================================================================================================
|
613 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
614 |
# for train_network_opt.py
|
615 |
-
|
616 |
-
|
617 |
-
parser.add_argument("--use_lookahead", action="store_true")
|
618 |
-
parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
|
619 |
parser.add_argument("--split_lora_networks", action="store_true")
|
620 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
621 |
-
parser.add_argument("--blocks_lr_setting", type=str, default=None)
|
622 |
-
parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
|
623 |
parser.add_argument("--min_resolution", type=str, default=None)
|
624 |
parser.add_argument("--area_step", type=int, default=1)
|
625 |
parser.add_argument("--config", type=str, default=None)
|
626 |
-
parser.add_argument("--not_output_config", action="store_true")
|
627 |
-
|
628 |
-
class MyNetwork_Names:
|
629 |
-
ex_block_weight_dic = {
|
630 |
-
"BASE": ["te"],
|
631 |
-
"IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
|
632 |
-
"IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
|
633 |
-
"IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
|
634 |
-
"IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
|
635 |
-
"MID": ["mid"],
|
636 |
-
"OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
|
637 |
-
"OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
|
638 |
-
"OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
|
639 |
-
"OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
|
640 |
-
}
|
641 |
-
|
642 |
-
blocks_name_dic = { "te": "lora_te_",
|
643 |
-
"unet": "lora_unet_",
|
644 |
-
"mid": "lora_unet_mid_block_",
|
645 |
-
"down": "lora_unet_down_blocks_",
|
646 |
-
"up": "lora_unet_up_blocks_"}
|
647 |
-
for i in range(12):
|
648 |
-
blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
|
649 |
-
for i in range(3):
|
650 |
-
blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
|
651 |
-
blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
|
652 |
-
for i in range(4):
|
653 |
-
for j in range(2):
|
654 |
-
if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
|
655 |
-
blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
|
656 |
-
for j in range(3):
|
657 |
-
if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
|
658 |
-
blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
|
659 |
-
if i<=2:
|
660 |
-
blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
|
661 |
-
blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
|
662 |
-
|
663 |
-
def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
|
664 |
-
ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
|
665 |
-
blocks_name_dic = MyNetwork_Names.blocks_name_dic
|
666 |
-
|
667 |
-
lr_dic = {}
|
668 |
-
if lr_setting_str==None or lr_setting_str=="":
|
669 |
-
pass
|
670 |
-
else:
|
671 |
-
lr_settings = lr_setting_str.replace(" ", "").split(",")
|
672 |
-
for lr_setting in lr_settings:
|
673 |
-
key, value = lr_setting.split("=")
|
674 |
-
if key in ex_block_weight_dic:
|
675 |
-
keys = ex_block_weight_dic[key]
|
676 |
-
else:
|
677 |
-
keys = [key]
|
678 |
-
for key in keys:
|
679 |
-
if key in blocks_name_dic:
|
680 |
-
new_key = blocks_name_dic[key]
|
681 |
-
lr_dic[new_key] = float(value)
|
682 |
-
if len(lr_dic)==0:
|
683 |
-
lr_dic = None
|
684 |
-
|
685 |
-
args_dic = {}
|
686 |
-
if (block_optim_args is None):
|
687 |
-
block_optim_args = []
|
688 |
-
if (len(block_optim_args)>0):
|
689 |
-
for my_arg in block_optim_args:
|
690 |
-
my_arg = my_arg.replace(" ", "")
|
691 |
-
splits = my_arg.split(":")
|
692 |
-
b_name = splits[0]
|
693 |
-
|
694 |
-
key, _value = splits[1].split("=")
|
695 |
-
value_type = float
|
696 |
-
if len(splits)==3:
|
697 |
-
if _value=="str":
|
698 |
-
value_type = str
|
699 |
-
elif _value=="int":
|
700 |
-
value_type = int
|
701 |
-
_value = splits[2]
|
702 |
-
if _value=="true" or _value=="false":
|
703 |
-
value_type = bool
|
704 |
-
if "," in _value:
|
705 |
-
_value = _value.split(",")
|
706 |
-
for i in range(len(_value)):
|
707 |
-
_value[i] = value_type(_value[i])
|
708 |
-
value=tuple(_value)
|
709 |
-
else:
|
710 |
-
value = value_type(_value)
|
711 |
-
|
712 |
-
if b_name in ex_block_weight_dic:
|
713 |
-
b_names = ex_block_weight_dic[b_name]
|
714 |
-
else:
|
715 |
-
b_names = [b_name]
|
716 |
-
for b_name in b_names:
|
717 |
-
new_b_name = blocks_name_dic[b_name]
|
718 |
-
if not new_b_name in args_dic:
|
719 |
-
args_dic[new_b_name] = {}
|
720 |
-
args_dic[new_b_name][key] = value
|
721 |
-
|
722 |
-
if len(args_dic)==0:
|
723 |
-
args_dic = None
|
724 |
-
return lr_dic, args_dic
|
725 |
|
726 |
def create_split_names(split_flag, split_level):
|
727 |
split_names = None
|
@@ -732,28 +446,14 @@ def create_split_names(split_flag, split_level):
|
|
732 |
if split_level==1:
|
733 |
unet_names.append(f"lora_unet_down_blocks_")
|
734 |
unet_names.append(f"lora_unet_up_blocks_")
|
735 |
-
elif split_level==2 or split_level==0
|
736 |
-
if split_level
|
737 |
text_encoder_names = []
|
738 |
for i in range(12):
|
739 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
unet_names.append(f"lora_unet_down_blocks_{i}")
|
744 |
-
unet_names.append(f"lora_unet_up_blocks_{i+1}")
|
745 |
-
|
746 |
-
if split_level>=3:
|
747 |
-
for i in range(4):
|
748 |
-
for j in range(2):
|
749 |
-
if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
|
750 |
-
if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
|
751 |
-
for j in range(3):
|
752 |
-
if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
|
753 |
-
if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
|
754 |
-
if i<=2:
|
755 |
-
unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
|
756 |
-
|
757 |
split_names["text_encoder"] = text_encoder_names
|
758 |
split_names["unet"] = unet_names
|
759 |
return split_names
|
@@ -765,7 +465,7 @@ def get_config(parser):
|
|
765 |
import datetime
|
766 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
767 |
args.config = os.path.splitext(args.config)[0]
|
768 |
-
config_path = f"{args.config}.yaml"
|
769 |
if os.path.exists(config_path):
|
770 |
print(f"{config_path} から設定を読み込み中...")
|
771 |
margs, rest = parser.parse_known_args()
|
@@ -786,41 +486,19 @@ def get_config(parser):
|
|
786 |
args_type_dic[key] = act.type
|
787 |
#データタイプの確認とargsにkeyの内容を代入していく
|
788 |
for key, v in configs.items():
|
789 |
-
if
|
790 |
-
if key
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
if not type(v) == args_type_dic[key]:
|
797 |
v = args_type_dic[key](v)
|
798 |
-
|
799 |
#最後にデフォから指定が変わってるものを変更する
|
800 |
for key, v in change_def_dic.items():
|
801 |
args_dic[key] = v
|
802 |
else:
|
803 |
print(f"{config_path} が見つかりませんでした")
|
804 |
return args
|
805 |
-
|
806 |
-
'''
|
807 |
-
class GradientReversalFunction(torch.autograd.Function):
|
808 |
-
@staticmethod
|
809 |
-
def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
810 |
-
ctx.save_for_backward(scale)
|
811 |
-
return input_forward
|
812 |
-
@staticmethod
|
813 |
-
def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
814 |
-
scale, = ctx.saved_tensors
|
815 |
-
return scale * -grad_backward, None
|
816 |
-
|
817 |
-
class GradientReversal(torch.nn.Module):
|
818 |
-
def __init__(self, scale: float):
|
819 |
-
super(GradientReversal, self).__init__()
|
820 |
-
self.scale = torch.tensor(scale)
|
821 |
-
def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
|
822 |
-
if flag:
|
823 |
-
return x
|
824 |
-
else:
|
825 |
-
return GradientReversalFunction.apply(x, self.scale)
|
826 |
-
'''
|
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
+
from typing import Dict, List, NamedTuple, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from accelerate import Accelerator
|
7 |
from torch.autograd.function import Function
|
8 |
import glob
|
|
|
28 |
|
29 |
import library.model_util as model_util
|
30 |
import library.train_util as train_util
|
|
|
31 |
|
32 |
#============================================================================================================
|
33 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
|
|
115 |
return area_size_resos_list, area_size_list
|
116 |
|
117 |
#============================================================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
#train_util 内より
|
119 |
#============================================================================================================
|
120 |
class BucketManager_append(train_util.BucketManager):
|
|
|
179 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
180 |
_min_error = 1000.
|
181 |
_min_id = bucket_size_id
|
182 |
+
for now_size_id in bucket_size_id:
|
183 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
184 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
185 |
ar_error = np.abs(ar_errors).min()
|
|
|
253 |
return reso, resized_size, ar_error
|
254 |
|
255 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
256 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
|
257 |
print("use append DreamBoothDataset")
|
258 |
self.min_resolution = min_resolution
|
259 |
self.area_step = area_step
|
260 |
+
super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
|
261 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
|
262 |
+
flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
263 |
def make_buckets(self):
|
264 |
'''
|
265 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
|
|
352 |
self.shuffle_buckets()
|
353 |
self._length = len(self.buckets_indices)
|
354 |
|
355 |
+
class FineTuningDataset(train_util.FineTuningDataset):
|
356 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
357 |
+
train_util.glob_images = glob_images
|
358 |
+
super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
359 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
|
360 |
+
random_crop, dataset_repeats, debug_dataset)
|
361 |
+
|
362 |
+
def glob_images(directory, base="*", npz_flag=True):
|
363 |
+
img_paths = []
|
364 |
+
dots = []
|
365 |
+
for ext in train_util.IMAGE_EXTENSIONS:
|
366 |
+
dots.append(ext)
|
367 |
+
if npz_flag:
|
368 |
+
dots.append(".npz")
|
369 |
+
for ext in dots:
|
370 |
+
if base == '*':
|
371 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
372 |
+
else:
|
373 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
374 |
+
return img_paths
|
375 |
+
|
376 |
#============================================================================================================
|
377 |
#networks.lora
|
378 |
#============================================================================================================
|
379 |
+
from networks.lora import LoRANetwork
|
380 |
+
def replace_prepare_optimizer_params(networks):
|
381 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
|
382 |
+
|
383 |
def enumerate_params(loras, lora_name=None):
|
384 |
params = []
|
385 |
for lora in loras:
|
386 |
if lora_name is not None:
|
387 |
+
if lora_name in lora.lora_name:
|
388 |
+
params.extend(lora.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
else:
|
390 |
params.extend(lora.parameters())
|
391 |
return params
|
|
|
393 |
self.requires_grad_(True)
|
394 |
all_params = []
|
395 |
ret_scheduler_lr = []
|
|
|
396 |
|
397 |
if loranames is not None:
|
398 |
textencoder_names = [None]
|
|
|
405 |
if self.text_encoder_loras:
|
406 |
for textencoder_name in textencoder_names:
|
407 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
|
|
408 |
if text_encoder_lr is not None:
|
409 |
param_data['lr'] = text_encoder_lr
|
410 |
+
if scheduler_lr is not None:
|
411 |
+
ret_scheduler_lr.append(scheduler_lr[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
all_params.append(param_data)
|
413 |
|
414 |
if self.unet_loras:
|
415 |
for unet_name in unet_names:
|
416 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
|
|
|
|
417 |
if unet_lr is not None:
|
418 |
param_data['lr'] = unet_lr
|
419 |
+
if scheduler_lr is not None:
|
420 |
+
ret_scheduler_lr.append(scheduler_lr[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
all_params.append(param_data)
|
422 |
|
423 |
+
return all_params, ret_scheduler_lr
|
424 |
+
|
425 |
+
LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
|
|
|
|
|
426 |
|
427 |
#============================================================================================================
|
428 |
#新規追加
|
429 |
#============================================================================================================
|
430 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
431 |
# for train_network_opt.py
|
432 |
+
parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
433 |
+
parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
|
|
|
|
434 |
parser.add_argument("--split_lora_networks", action="store_true")
|
435 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
|
|
|
|
436 |
parser.add_argument("--min_resolution", type=str, default=None)
|
437 |
parser.add_argument("--area_step", type=int, default=1)
|
438 |
parser.add_argument("--config", type=str, default=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
def create_split_names(split_flag, split_level):
|
441 |
split_names = None
|
|
|
446 |
if split_level==1:
|
447 |
unet_names.append(f"lora_unet_down_blocks_")
|
448 |
unet_names.append(f"lora_unet_up_blocks_")
|
449 |
+
elif split_level==2 or split_level==0:
|
450 |
+
if split_level==2:
|
451 |
text_encoder_names = []
|
452 |
for i in range(12):
|
453 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
454 |
+
for i in range(3):
|
455 |
+
unet_names.append(f"lora_unet_down_blocks_{i}")
|
456 |
+
unet_names.append(f"lora_unet_up_blocks_{i+1}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
split_names["text_encoder"] = text_encoder_names
|
458 |
split_names["unet"] = unet_names
|
459 |
return split_names
|
|
|
465 |
import datetime
|
466 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
467 |
args.config = os.path.splitext(args.config)[0]
|
468 |
+
config_path = f"./{args.config}.yaml"
|
469 |
if os.path.exists(config_path):
|
470 |
print(f"{config_path} から設定を読み込み中...")
|
471 |
margs, rest = parser.parse_known_args()
|
|
|
486 |
args_type_dic[key] = act.type
|
487 |
#データタイプの確認とargsにkeyの内容を代入していく
|
488 |
for key, v in configs.items():
|
489 |
+
if key in args_dic:
|
490 |
+
if args_dic[key] is not None:
|
491 |
+
new_type = type(args_dic[key])
|
492 |
+
if (not type(v) == new_type) and (not new_type==list):
|
493 |
+
v = new_type(v)
|
494 |
+
else:
|
495 |
+
if v is not None:
|
496 |
if not type(v) == args_type_dic[key]:
|
497 |
v = args_type_dic[key](v)
|
498 |
+
args_dic[key] = v
|
499 |
#最後にデフォから指定が変わってるものを変更する
|
500 |
for key, v in change_def_dic.items():
|
501 |
args_dic[key] = v
|
502 |
else:
|
503 |
print(f"{config_path} が見つかりませんでした")
|
504 |
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/lib/library/__init__.py
ADDED
File without changes
|
build/lib/library/model_util.py
ADDED
@@ -0,0 +1,1180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
+
from safetensors.torch import load_file, save_file
|
10 |
+
|
11 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
12 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
13 |
+
BETA_START = 0.00085
|
14 |
+
BETA_END = 0.0120
|
15 |
+
|
16 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
17 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
18 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
19 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
20 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
21 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
22 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
23 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
24 |
+
UNET_PARAMS_NUM_HEADS = 8
|
25 |
+
|
26 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
27 |
+
VAE_PARAMS_RESOLUTION = 256
|
28 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
29 |
+
VAE_PARAMS_OUT_CH = 3
|
30 |
+
VAE_PARAMS_CH = 128
|
31 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
32 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
|
34 |
+
# V2
|
35 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
36 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
37 |
+
|
38 |
+
# Diffusersの設定を読み込むための参照モデル
|
39 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
40 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
41 |
+
|
42 |
+
|
43 |
+
# region StableDiffusion->Diffusersの変換コード
|
44 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
45 |
+
|
46 |
+
|
47 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
48 |
+
"""
|
49 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
50 |
+
"""
|
51 |
+
if n_shave_prefix_segments >= 0:
|
52 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
53 |
+
else:
|
54 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
55 |
+
|
56 |
+
|
57 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
58 |
+
"""
|
59 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
60 |
+
"""
|
61 |
+
mapping = []
|
62 |
+
for old_item in old_list:
|
63 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
64 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
65 |
+
|
66 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
67 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
68 |
+
|
69 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
70 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
71 |
+
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
+
|
74 |
+
mapping.append({"old": old_item, "new": new_item})
|
75 |
+
|
76 |
+
return mapping
|
77 |
+
|
78 |
+
|
79 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
+
|
87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
89 |
+
|
90 |
+
mapping.append({"old": old_item, "new": new_item})
|
91 |
+
|
92 |
+
return mapping
|
93 |
+
|
94 |
+
|
95 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
96 |
+
"""
|
97 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
98 |
+
"""
|
99 |
+
mapping = []
|
100 |
+
for old_item in old_list:
|
101 |
+
new_item = old_item
|
102 |
+
|
103 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
104 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
105 |
+
|
106 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
107 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
108 |
+
|
109 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
110 |
+
|
111 |
+
mapping.append({"old": old_item, "new": new_item})
|
112 |
+
|
113 |
+
return mapping
|
114 |
+
|
115 |
+
|
116 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
117 |
+
"""
|
118 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
119 |
+
"""
|
120 |
+
mapping = []
|
121 |
+
for old_item in old_list:
|
122 |
+
new_item = old_item
|
123 |
+
|
124 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
125 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
126 |
+
|
127 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
128 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
131 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
134 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
137 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
138 |
+
|
139 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
140 |
+
|
141 |
+
mapping.append({"old": old_item, "new": new_item})
|
142 |
+
|
143 |
+
return mapping
|
144 |
+
|
145 |
+
|
146 |
+
def assign_to_checkpoint(
|
147 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
151 |
+
to them. It splits attention layers, and takes into account additional replacements
|
152 |
+
that may arise.
|
153 |
+
|
154 |
+
Assigns the weights to the new checkpoint.
|
155 |
+
"""
|
156 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
157 |
+
|
158 |
+
# Splits the attention layers into three variables.
|
159 |
+
if attention_paths_to_split is not None:
|
160 |
+
for path, path_map in attention_paths_to_split.items():
|
161 |
+
old_tensor = old_checkpoint[path]
|
162 |
+
channels = old_tensor.shape[0] // 3
|
163 |
+
|
164 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
165 |
+
|
166 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
167 |
+
|
168 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
169 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
170 |
+
|
171 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
172 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
173 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
174 |
+
|
175 |
+
for path in paths:
|
176 |
+
new_path = path["new"]
|
177 |
+
|
178 |
+
# These have already been assigned
|
179 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
180 |
+
continue
|
181 |
+
|
182 |
+
# Global renaming happens here
|
183 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
184 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
185 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
186 |
+
|
187 |
+
if additional_replacements is not None:
|
188 |
+
for replacement in additional_replacements:
|
189 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
190 |
+
|
191 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
192 |
+
if "proj_attn.weight" in new_path:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
194 |
+
else:
|
195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
196 |
+
|
197 |
+
|
198 |
+
def conv_attn_to_linear(checkpoint):
|
199 |
+
keys = list(checkpoint.keys())
|
200 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
201 |
+
for key in keys:
|
202 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
205 |
+
elif "proj_attn.weight" in key:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
208 |
+
|
209 |
+
|
210 |
+
def linear_transformer_to_conv(checkpoint):
|
211 |
+
keys = list(checkpoint.keys())
|
212 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
213 |
+
for key in keys:
|
214 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
215 |
+
if checkpoint[key].ndim == 2:
|
216 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
217 |
+
|
218 |
+
|
219 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
220 |
+
"""
|
221 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
222 |
+
"""
|
223 |
+
|
224 |
+
# extract state_dict for UNet
|
225 |
+
unet_state_dict = {}
|
226 |
+
unet_key = "model.diffusion_model."
|
227 |
+
keys = list(checkpoint.keys())
|
228 |
+
for key in keys:
|
229 |
+
if key.startswith(unet_key):
|
230 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
231 |
+
|
232 |
+
new_checkpoint = {}
|
233 |
+
|
234 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
235 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
236 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
237 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
238 |
+
|
239 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
240 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
243 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
244 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
245 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
246 |
+
|
247 |
+
# Retrieves the keys for the input blocks only
|
248 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
249 |
+
input_blocks = {
|
250 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
251 |
+
for layer_id in range(num_input_blocks)
|
252 |
+
}
|
253 |
+
|
254 |
+
# Retrieves the keys for the middle blocks only
|
255 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
256 |
+
middle_blocks = {
|
257 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
258 |
+
for layer_id in range(num_middle_blocks)
|
259 |
+
}
|
260 |
+
|
261 |
+
# Retrieves the keys for the output blocks only
|
262 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
263 |
+
output_blocks = {
|
264 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
265 |
+
for layer_id in range(num_output_blocks)
|
266 |
+
}
|
267 |
+
|
268 |
+
for i in range(1, num_input_blocks):
|
269 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
270 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
271 |
+
|
272 |
+
resnets = [
|
273 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
274 |
+
]
|
275 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
276 |
+
|
277 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
278 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
279 |
+
f"input_blocks.{i}.0.op.weight"
|
280 |
+
)
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.bias"
|
283 |
+
)
|
284 |
+
|
285 |
+
paths = renew_resnet_paths(resnets)
|
286 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
287 |
+
assign_to_checkpoint(
|
288 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
289 |
+
)
|
290 |
+
|
291 |
+
if len(attentions):
|
292 |
+
paths = renew_attention_paths(attentions)
|
293 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
294 |
+
assign_to_checkpoint(
|
295 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
296 |
+
)
|
297 |
+
|
298 |
+
resnet_0 = middle_blocks[0]
|
299 |
+
attentions = middle_blocks[1]
|
300 |
+
resnet_1 = middle_blocks[2]
|
301 |
+
|
302 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
303 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
304 |
+
|
305 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
306 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
attentions_paths = renew_attention_paths(attentions)
|
309 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
310 |
+
assign_to_checkpoint(
|
311 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
312 |
+
)
|
313 |
+
|
314 |
+
for i in range(num_output_blocks):
|
315 |
+
block_id = i // (config["layers_per_block"] + 1)
|
316 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
317 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
318 |
+
output_block_list = {}
|
319 |
+
|
320 |
+
for layer in output_block_layers:
|
321 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
322 |
+
if layer_id in output_block_list:
|
323 |
+
output_block_list[layer_id].append(layer_name)
|
324 |
+
else:
|
325 |
+
output_block_list[layer_id] = [layer_name]
|
326 |
+
|
327 |
+
if len(output_block_list) > 1:
|
328 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
329 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
330 |
+
|
331 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
332 |
+
paths = renew_resnet_paths(resnets)
|
333 |
+
|
334 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
335 |
+
assign_to_checkpoint(
|
336 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
337 |
+
)
|
338 |
+
|
339 |
+
# オリジナル:
|
340 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
341 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
342 |
+
|
343 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
344 |
+
for l in output_block_list.values():
|
345 |
+
l.sort()
|
346 |
+
|
347 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
348 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
349 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
350 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
351 |
+
]
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
354 |
+
]
|
355 |
+
|
356 |
+
# Clear attentions as they have been attributed above.
|
357 |
+
if len(attentions) == 2:
|
358 |
+
attentions = []
|
359 |
+
|
360 |
+
if len(attentions):
|
361 |
+
paths = renew_attention_paths(attentions)
|
362 |
+
meta_path = {
|
363 |
+
"old": f"output_blocks.{i}.1",
|
364 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
365 |
+
}
|
366 |
+
assign_to_checkpoint(
|
367 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
371 |
+
for path in resnet_0_paths:
|
372 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
373 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
374 |
+
|
375 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
376 |
+
|
377 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
378 |
+
if v2:
|
379 |
+
linear_transformer_to_conv(new_checkpoint)
|
380 |
+
|
381 |
+
return new_checkpoint
|
382 |
+
|
383 |
+
|
384 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
385 |
+
# extract state dict for VAE
|
386 |
+
vae_state_dict = {}
|
387 |
+
vae_key = "first_stage_model."
|
388 |
+
keys = list(checkpoint.keys())
|
389 |
+
for key in keys:
|
390 |
+
if key.startswith(vae_key):
|
391 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
392 |
+
# if len(vae_state_dict) == 0:
|
393 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
394 |
+
# vae_state_dict = checkpoint
|
395 |
+
|
396 |
+
new_checkpoint = {}
|
397 |
+
|
398 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
399 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
400 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
401 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
402 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
403 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
404 |
+
|
405 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
406 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
407 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
408 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
409 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
410 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
411 |
+
|
412 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
413 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
414 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
415 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
416 |
+
|
417 |
+
# Retrieves the keys for the encoder down blocks only
|
418 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
419 |
+
down_blocks = {
|
420 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
421 |
+
}
|
422 |
+
|
423 |
+
# Retrieves the keys for the decoder up blocks only
|
424 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
425 |
+
up_blocks = {
|
426 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
427 |
+
}
|
428 |
+
|
429 |
+
for i in range(num_down_blocks):
|
430 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
431 |
+
|
432 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
433 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
434 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
435 |
+
)
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
438 |
+
)
|
439 |
+
|
440 |
+
paths = renew_vae_resnet_paths(resnets)
|
441 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
442 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
443 |
+
|
444 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
445 |
+
num_mid_res_blocks = 2
|
446 |
+
for i in range(1, num_mid_res_blocks + 1):
|
447 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
448 |
+
|
449 |
+
paths = renew_vae_resnet_paths(resnets)
|
450 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
451 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
452 |
+
|
453 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
454 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
455 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
456 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
457 |
+
conv_attn_to_linear(new_checkpoint)
|
458 |
+
|
459 |
+
for i in range(num_up_blocks):
|
460 |
+
block_id = num_up_blocks - 1 - i
|
461 |
+
resnets = [
|
462 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
463 |
+
]
|
464 |
+
|
465 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
466 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
467 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
468 |
+
]
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
471 |
+
]
|
472 |
+
|
473 |
+
paths = renew_vae_resnet_paths(resnets)
|
474 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
475 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
476 |
+
|
477 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
478 |
+
num_mid_res_blocks = 2
|
479 |
+
for i in range(1, num_mid_res_blocks + 1):
|
480 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
481 |
+
|
482 |
+
paths = renew_vae_resnet_paths(resnets)
|
483 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
484 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
485 |
+
|
486 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
487 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
488 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
489 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
490 |
+
conv_attn_to_linear(new_checkpoint)
|
491 |
+
return new_checkpoint
|
492 |
+
|
493 |
+
|
494 |
+
def create_unet_diffusers_config(v2):
|
495 |
+
"""
|
496 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
497 |
+
"""
|
498 |
+
# unet_params = original_config.model.params.unet_config.params
|
499 |
+
|
500 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
501 |
+
|
502 |
+
down_block_types = []
|
503 |
+
resolution = 1
|
504 |
+
for i in range(len(block_out_channels)):
|
505 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
506 |
+
down_block_types.append(block_type)
|
507 |
+
if i != len(block_out_channels) - 1:
|
508 |
+
resolution *= 2
|
509 |
+
|
510 |
+
up_block_types = []
|
511 |
+
for i in range(len(block_out_channels)):
|
512 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
513 |
+
up_block_types.append(block_type)
|
514 |
+
resolution //= 2
|
515 |
+
|
516 |
+
config = dict(
|
517 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
518 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
519 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
520 |
+
down_block_types=tuple(down_block_types),
|
521 |
+
up_block_types=tuple(up_block_types),
|
522 |
+
block_out_channels=tuple(block_out_channels),
|
523 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
524 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
525 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
526 |
+
)
|
527 |
+
|
528 |
+
return config
|
529 |
+
|
530 |
+
|
531 |
+
def create_vae_diffusers_config():
|
532 |
+
"""
|
533 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
534 |
+
"""
|
535 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
536 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
537 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
538 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
539 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
540 |
+
|
541 |
+
config = dict(
|
542 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
543 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
544 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
545 |
+
down_block_types=tuple(down_block_types),
|
546 |
+
up_block_types=tuple(up_block_types),
|
547 |
+
block_out_channels=tuple(block_out_channels),
|
548 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
549 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
550 |
+
)
|
551 |
+
return config
|
552 |
+
|
553 |
+
|
554 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
555 |
+
keys = list(checkpoint.keys())
|
556 |
+
text_model_dict = {}
|
557 |
+
for key in keys:
|
558 |
+
if key.startswith("cond_stage_model.transformer"):
|
559 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
560 |
+
return text_model_dict
|
561 |
+
|
562 |
+
|
563 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
564 |
+
# 嫌になるくらい違うぞ!
|
565 |
+
def convert_key(key):
|
566 |
+
if not key.startswith("cond_stage_model"):
|
567 |
+
return None
|
568 |
+
|
569 |
+
# common conversion
|
570 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
571 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
572 |
+
|
573 |
+
if "resblocks" in key:
|
574 |
+
# resblocks conversion
|
575 |
+
key = key.replace(".resblocks.", ".layers.")
|
576 |
+
if ".ln_" in key:
|
577 |
+
key = key.replace(".ln_", ".layer_norm")
|
578 |
+
elif ".mlp." in key:
|
579 |
+
key = key.replace(".c_fc.", ".fc1.")
|
580 |
+
key = key.replace(".c_proj.", ".fc2.")
|
581 |
+
elif '.attn.out_proj' in key:
|
582 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
583 |
+
elif '.attn.in_proj' in key:
|
584 |
+
key = None # 特殊なので後で処理する
|
585 |
+
else:
|
586 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
587 |
+
elif '.positional_embedding' in key:
|
588 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
589 |
+
elif '.text_projection' in key:
|
590 |
+
key = None # 使われない???
|
591 |
+
elif '.logit_scale' in key:
|
592 |
+
key = None # 使われない???
|
593 |
+
elif '.token_embedding' in key:
|
594 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
595 |
+
elif '.ln_final' in key:
|
596 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
597 |
+
return key
|
598 |
+
|
599 |
+
keys = list(checkpoint.keys())
|
600 |
+
new_sd = {}
|
601 |
+
for key in keys:
|
602 |
+
# remove resblocks 23
|
603 |
+
if '.resblocks.23.' in key:
|
604 |
+
continue
|
605 |
+
new_key = convert_key(key)
|
606 |
+
if new_key is None:
|
607 |
+
continue
|
608 |
+
new_sd[new_key] = checkpoint[key]
|
609 |
+
|
610 |
+
# attnの変換
|
611 |
+
for key in keys:
|
612 |
+
if '.resblocks.23.' in key:
|
613 |
+
continue
|
614 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
615 |
+
# 三つに分割
|
616 |
+
values = torch.chunk(checkpoint[key], 3)
|
617 |
+
|
618 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
619 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
620 |
+
key_pfx = key_pfx.replace("_weight", "")
|
621 |
+
key_pfx = key_pfx.replace("_bias", "")
|
622 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
623 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
624 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
625 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
626 |
+
|
627 |
+
# rename or add position_ids
|
628 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
629 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
630 |
+
# waifu diffusion v1.4
|
631 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
632 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
633 |
+
else:
|
634 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
635 |
+
|
636 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
637 |
+
return new_sd
|
638 |
+
|
639 |
+
# endregion
|
640 |
+
|
641 |
+
|
642 |
+
# region Diffusers->StableDiffusion の変換コード
|
643 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
644 |
+
|
645 |
+
def conv_transformer_to_linear(checkpoint):
|
646 |
+
keys = list(checkpoint.keys())
|
647 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
648 |
+
for key in keys:
|
649 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
650 |
+
if checkpoint[key].ndim > 2:
|
651 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
652 |
+
|
653 |
+
|
654 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
655 |
+
unet_conversion_map = [
|
656 |
+
# (stable-diffusion, HF Diffusers)
|
657 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
658 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
659 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
660 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
661 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
662 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
663 |
+
("out.0.weight", "conv_norm_out.weight"),
|
664 |
+
("out.0.bias", "conv_norm_out.bias"),
|
665 |
+
("out.2.weight", "conv_out.weight"),
|
666 |
+
("out.2.bias", "conv_out.bias"),
|
667 |
+
]
|
668 |
+
|
669 |
+
unet_conversion_map_resnet = [
|
670 |
+
# (stable-diffusion, HF Diffusers)
|
671 |
+
("in_layers.0", "norm1"),
|
672 |
+
("in_layers.2", "conv1"),
|
673 |
+
("out_layers.0", "norm2"),
|
674 |
+
("out_layers.3", "conv2"),
|
675 |
+
("emb_layers.1", "time_emb_proj"),
|
676 |
+
("skip_connection", "conv_shortcut"),
|
677 |
+
]
|
678 |
+
|
679 |
+
unet_conversion_map_layer = []
|
680 |
+
for i in range(4):
|
681 |
+
# loop over downblocks/upblocks
|
682 |
+
|
683 |
+
for j in range(2):
|
684 |
+
# loop over resnets/attentions for downblocks
|
685 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
686 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
687 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
688 |
+
|
689 |
+
if i < 3:
|
690 |
+
# no attention layers in down_blocks.3
|
691 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
692 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
693 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
694 |
+
|
695 |
+
for j in range(3):
|
696 |
+
# loop over resnets/attentions for upblocks
|
697 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
698 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
699 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
700 |
+
|
701 |
+
if i > 0:
|
702 |
+
# no attention layers in up_blocks.0
|
703 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
704 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
705 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
706 |
+
|
707 |
+
if i < 3:
|
708 |
+
# no downsample in down_blocks.3
|
709 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
710 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
711 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
712 |
+
|
713 |
+
# no upsample in up_blocks.3
|
714 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
715 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
716 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
717 |
+
|
718 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
719 |
+
sd_mid_atn_prefix = "middle_block.1."
|
720 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
721 |
+
|
722 |
+
for j in range(2):
|
723 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
724 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
725 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
726 |
+
|
727 |
+
# buyer beware: this is a *brittle* function,
|
728 |
+
# and correct output requires that all of these pieces interact in
|
729 |
+
# the exact order in which I have arranged them.
|
730 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
731 |
+
for sd_name, hf_name in unet_conversion_map:
|
732 |
+
mapping[hf_name] = sd_name
|
733 |
+
for k, v in mapping.items():
|
734 |
+
if "resnets" in k:
|
735 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
736 |
+
v = v.replace(hf_part, sd_part)
|
737 |
+
mapping[k] = v
|
738 |
+
for k, v in mapping.items():
|
739 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
740 |
+
v = v.replace(hf_part, sd_part)
|
741 |
+
mapping[k] = v
|
742 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
743 |
+
|
744 |
+
if v2:
|
745 |
+
conv_transformer_to_linear(new_state_dict)
|
746 |
+
|
747 |
+
return new_state_dict
|
748 |
+
|
749 |
+
|
750 |
+
# ================#
|
751 |
+
# VAE Conversion #
|
752 |
+
# ================#
|
753 |
+
|
754 |
+
def reshape_weight_for_sd(w):
|
755 |
+
# convert HF linear weights to SD conv2d weights
|
756 |
+
return w.reshape(*w.shape, 1, 1)
|
757 |
+
|
758 |
+
|
759 |
+
def convert_vae_state_dict(vae_state_dict):
|
760 |
+
vae_conversion_map = [
|
761 |
+
# (stable-diffusion, HF Diffusers)
|
762 |
+
("nin_shortcut", "conv_shortcut"),
|
763 |
+
("norm_out", "conv_norm_out"),
|
764 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
765 |
+
]
|
766 |
+
|
767 |
+
for i in range(4):
|
768 |
+
# down_blocks have two resnets
|
769 |
+
for j in range(2):
|
770 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
771 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
772 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
773 |
+
|
774 |
+
if i < 3:
|
775 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
776 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
777 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
778 |
+
|
779 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
780 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
781 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
782 |
+
|
783 |
+
# up_blocks have three resnets
|
784 |
+
# also, up blocks in hf are numbered in reverse from sd
|
785 |
+
for j in range(3):
|
786 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
787 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
788 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
789 |
+
|
790 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
791 |
+
for i in range(2):
|
792 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
793 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
794 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
795 |
+
|
796 |
+
vae_conversion_map_attn = [
|
797 |
+
# (stable-diffusion, HF Diffusers)
|
798 |
+
("norm.", "group_norm."),
|
799 |
+
("q.", "query."),
|
800 |
+
("k.", "key."),
|
801 |
+
("v.", "value."),
|
802 |
+
("proj_out.", "proj_attn."),
|
803 |
+
]
|
804 |
+
|
805 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
806 |
+
for k, v in mapping.items():
|
807 |
+
for sd_part, hf_part in vae_conversion_map:
|
808 |
+
v = v.replace(hf_part, sd_part)
|
809 |
+
mapping[k] = v
|
810 |
+
for k, v in mapping.items():
|
811 |
+
if "attentions" in k:
|
812 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
813 |
+
v = v.replace(hf_part, sd_part)
|
814 |
+
mapping[k] = v
|
815 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
816 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
817 |
+
for k, v in new_state_dict.items():
|
818 |
+
for weight_name in weights_to_convert:
|
819 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
820 |
+
# print(f"Reshaping {k} for SD format")
|
821 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
822 |
+
|
823 |
+
return new_state_dict
|
824 |
+
|
825 |
+
|
826 |
+
# endregion
|
827 |
+
|
828 |
+
# region 自作のモデル読み書きなど
|
829 |
+
|
830 |
+
def is_safetensors(path):
|
831 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
832 |
+
|
833 |
+
|
834 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
835 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
836 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
837 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
838 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
839 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
840 |
+
]
|
841 |
+
|
842 |
+
if is_safetensors(ckpt_path):
|
843 |
+
checkpoint = None
|
844 |
+
state_dict = load_file(ckpt_path, "cpu")
|
845 |
+
else:
|
846 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
847 |
+
if "state_dict" in checkpoint:
|
848 |
+
state_dict = checkpoint["state_dict"]
|
849 |
+
else:
|
850 |
+
state_dict = checkpoint
|
851 |
+
checkpoint = None
|
852 |
+
|
853 |
+
key_reps = []
|
854 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
855 |
+
for key in state_dict.keys():
|
856 |
+
if key.startswith(rep_from):
|
857 |
+
new_key = rep_to + key[len(rep_from):]
|
858 |
+
key_reps.append((key, new_key))
|
859 |
+
|
860 |
+
for key, new_key in key_reps:
|
861 |
+
state_dict[new_key] = state_dict[key]
|
862 |
+
del state_dict[key]
|
863 |
+
|
864 |
+
return checkpoint, state_dict
|
865 |
+
|
866 |
+
|
867 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
868 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
869 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
870 |
+
if dtype is not None:
|
871 |
+
for k, v in state_dict.items():
|
872 |
+
if type(v) is torch.Tensor:
|
873 |
+
state_dict[k] = v.to(dtype)
|
874 |
+
|
875 |
+
# Convert the UNet2DConditionModel model.
|
876 |
+
unet_config = create_unet_diffusers_config(v2)
|
877 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
878 |
+
|
879 |
+
unet = UNet2DConditionModel(**unet_config)
|
880 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
881 |
+
print("loading u-net:", info)
|
882 |
+
|
883 |
+
# Convert the VAE model.
|
884 |
+
vae_config = create_vae_diffusers_config()
|
885 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
886 |
+
|
887 |
+
vae = AutoencoderKL(**vae_config)
|
888 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
889 |
+
print("loading vae:", info)
|
890 |
+
|
891 |
+
# convert text_model
|
892 |
+
if v2:
|
893 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
894 |
+
cfg = CLIPTextConfig(
|
895 |
+
vocab_size=49408,
|
896 |
+
hidden_size=1024,
|
897 |
+
intermediate_size=4096,
|
898 |
+
num_hidden_layers=23,
|
899 |
+
num_attention_heads=16,
|
900 |
+
max_position_embeddings=77,
|
901 |
+
hidden_act="gelu",
|
902 |
+
layer_norm_eps=1e-05,
|
903 |
+
dropout=0.0,
|
904 |
+
attention_dropout=0.0,
|
905 |
+
initializer_range=0.02,
|
906 |
+
initializer_factor=1.0,
|
907 |
+
pad_token_id=1,
|
908 |
+
bos_token_id=0,
|
909 |
+
eos_token_id=2,
|
910 |
+
model_type="clip_text_model",
|
911 |
+
projection_dim=512,
|
912 |
+
torch_dtype="float32",
|
913 |
+
transformers_version="4.25.0.dev0",
|
914 |
+
)
|
915 |
+
text_model = CLIPTextModel._from_config(cfg)
|
916 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
+
else:
|
918 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
920 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
+
print("loading text encoder:", info)
|
922 |
+
|
923 |
+
return text_model, vae, unet
|
924 |
+
|
925 |
+
|
926 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
927 |
+
def convert_key(key):
|
928 |
+
# position_idsの除去
|
929 |
+
if ".position_ids" in key:
|
930 |
+
return None
|
931 |
+
|
932 |
+
# common
|
933 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
934 |
+
key = key.replace("text_model.", "")
|
935 |
+
if "layers" in key:
|
936 |
+
# resblocks conversion
|
937 |
+
key = key.replace(".layers.", ".resblocks.")
|
938 |
+
if ".layer_norm" in key:
|
939 |
+
key = key.replace(".layer_norm", ".ln_")
|
940 |
+
elif ".mlp." in key:
|
941 |
+
key = key.replace(".fc1.", ".c_fc.")
|
942 |
+
key = key.replace(".fc2.", ".c_proj.")
|
943 |
+
elif '.self_attn.out_proj' in key:
|
944 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
945 |
+
elif '.self_attn.' in key:
|
946 |
+
key = None # 特殊なので後で処理する
|
947 |
+
else:
|
948 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
949 |
+
elif '.position_embedding' in key:
|
950 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
951 |
+
elif '.token_embedding' in key:
|
952 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
953 |
+
elif 'final_layer_norm' in key:
|
954 |
+
key = key.replace("final_layer_norm", "ln_final")
|
955 |
+
return key
|
956 |
+
|
957 |
+
keys = list(checkpoint.keys())
|
958 |
+
new_sd = {}
|
959 |
+
for key in keys:
|
960 |
+
new_key = convert_key(key)
|
961 |
+
if new_key is None:
|
962 |
+
continue
|
963 |
+
new_sd[new_key] = checkpoint[key]
|
964 |
+
|
965 |
+
# attnの変換
|
966 |
+
for key in keys:
|
967 |
+
if 'layers' in key and 'q_proj' in key:
|
968 |
+
# 三つを結合
|
969 |
+
key_q = key
|
970 |
+
key_k = key.replace("q_proj", "k_proj")
|
971 |
+
key_v = key.replace("q_proj", "v_proj")
|
972 |
+
|
973 |
+
value_q = checkpoint[key_q]
|
974 |
+
value_k = checkpoint[key_k]
|
975 |
+
value_v = checkpoint[key_v]
|
976 |
+
value = torch.cat([value_q, value_k, value_v])
|
977 |
+
|
978 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
979 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
980 |
+
new_sd[new_key] = value
|
981 |
+
|
982 |
+
# 最後の層などを捏造するか
|
983 |
+
if make_dummy_weights:
|
984 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
985 |
+
keys = list(new_sd.keys())
|
986 |
+
for key in keys:
|
987 |
+
if key.startswith("transformer.resblocks.22."):
|
988 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
989 |
+
|
990 |
+
# Diffusersに含まれない重みを作っておく
|
991 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
992 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
993 |
+
|
994 |
+
return new_sd
|
995 |
+
|
996 |
+
|
997 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
998 |
+
if ckpt_path is not None:
|
999 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1000 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1001 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1002 |
+
checkpoint = {}
|
1003 |
+
strict = False
|
1004 |
+
else:
|
1005 |
+
strict = True
|
1006 |
+
if "state_dict" in state_dict:
|
1007 |
+
del state_dict["state_dict"]
|
1008 |
+
else:
|
1009 |
+
# 新しく作る
|
1010 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1011 |
+
checkpoint = {}
|
1012 |
+
state_dict = {}
|
1013 |
+
strict = False
|
1014 |
+
|
1015 |
+
def update_sd(prefix, sd):
|
1016 |
+
for k, v in sd.items():
|
1017 |
+
key = prefix + k
|
1018 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1019 |
+
if save_dtype is not None:
|
1020 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1021 |
+
state_dict[key] = v
|
1022 |
+
|
1023 |
+
# Convert the UNet model
|
1024 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1025 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1026 |
+
|
1027 |
+
# Convert the text encoder model
|
1028 |
+
if v2:
|
1029 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製��て作るなどダミーの重みを入れる
|
1030 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1031 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1032 |
+
else:
|
1033 |
+
text_enc_dict = text_encoder.state_dict()
|
1034 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1035 |
+
|
1036 |
+
# Convert the VAE
|
1037 |
+
if vae is not None:
|
1038 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1039 |
+
update_sd("first_stage_model.", vae_dict)
|
1040 |
+
|
1041 |
+
# Put together new checkpoint
|
1042 |
+
key_count = len(state_dict.keys())
|
1043 |
+
new_ckpt = {'state_dict': state_dict}
|
1044 |
+
|
1045 |
+
if 'epoch' in checkpoint:
|
1046 |
+
epochs += checkpoint['epoch']
|
1047 |
+
if 'global_step' in checkpoint:
|
1048 |
+
steps += checkpoint['global_step']
|
1049 |
+
|
1050 |
+
new_ckpt['epoch'] = epochs
|
1051 |
+
new_ckpt['global_step'] = steps
|
1052 |
+
|
1053 |
+
if is_safetensors(output_file):
|
1054 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1055 |
+
save_file(state_dict, output_file)
|
1056 |
+
else:
|
1057 |
+
torch.save(new_ckpt, output_file)
|
1058 |
+
|
1059 |
+
return key_count
|
1060 |
+
|
1061 |
+
|
1062 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1063 |
+
if pretrained_model_name_or_path is None:
|
1064 |
+
# load default settings for v1/v2
|
1065 |
+
if v2:
|
1066 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1067 |
+
else:
|
1068 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1069 |
+
|
1070 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1071 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1072 |
+
if vae is None:
|
1073 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1074 |
+
|
1075 |
+
pipeline = StableDiffusionPipeline(
|
1076 |
+
unet=unet,
|
1077 |
+
text_encoder=text_encoder,
|
1078 |
+
vae=vae,
|
1079 |
+
scheduler=scheduler,
|
1080 |
+
tokenizer=tokenizer,
|
1081 |
+
safety_checker=None,
|
1082 |
+
feature_extractor=None,
|
1083 |
+
requires_safety_checker=None,
|
1084 |
+
)
|
1085 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1086 |
+
|
1087 |
+
|
1088 |
+
VAE_PREFIX = "first_stage_model."
|
1089 |
+
|
1090 |
+
|
1091 |
+
def load_vae(vae_id, dtype):
|
1092 |
+
print(f"load VAE: {vae_id}")
|
1093 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1094 |
+
# Diffusers local/remote
|
1095 |
+
try:
|
1096 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1097 |
+
except EnvironmentError as e:
|
1098 |
+
print(f"exception occurs in loading vae: {e}")
|
1099 |
+
print("retry with subfolder='vae'")
|
1100 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1101 |
+
return vae
|
1102 |
+
|
1103 |
+
# local
|
1104 |
+
vae_config = create_vae_diffusers_config()
|
1105 |
+
|
1106 |
+
if vae_id.endswith(".bin"):
|
1107 |
+
# SD 1.5 VAE on Huggingface
|
1108 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1109 |
+
else:
|
1110 |
+
# StableDiffusion
|
1111 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1112 |
+
else torch.load(vae_id, map_location="cpu"))
|
1113 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1114 |
+
|
1115 |
+
# vae only or full model
|
1116 |
+
full_model = False
|
1117 |
+
for vae_key in vae_sd:
|
1118 |
+
if vae_key.startswith(VAE_PREFIX):
|
1119 |
+
full_model = True
|
1120 |
+
break
|
1121 |
+
if not full_model:
|
1122 |
+
sd = {}
|
1123 |
+
for key, value in vae_sd.items():
|
1124 |
+
sd[VAE_PREFIX + key] = value
|
1125 |
+
vae_sd = sd
|
1126 |
+
del sd
|
1127 |
+
|
1128 |
+
# Convert the VAE model.
|
1129 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1130 |
+
|
1131 |
+
vae = AutoencoderKL(**vae_config)
|
1132 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1133 |
+
return vae
|
1134 |
+
|
1135 |
+
# endregion
|
1136 |
+
|
1137 |
+
|
1138 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1139 |
+
max_width, max_height = max_reso
|
1140 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1141 |
+
|
1142 |
+
resos = set()
|
1143 |
+
|
1144 |
+
size = int(math.sqrt(max_area)) * divisible
|
1145 |
+
resos.add((size, size))
|
1146 |
+
|
1147 |
+
size = min_size
|
1148 |
+
while size <= max_size:
|
1149 |
+
width = size
|
1150 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1151 |
+
resos.add((width, height))
|
1152 |
+
resos.add((height, width))
|
1153 |
+
|
1154 |
+
# # make additional resos
|
1155 |
+
# if width >= height and width - divisible >= min_size:
|
1156 |
+
# resos.add((width - divisible, height))
|
1157 |
+
# resos.add((height, width - divisible))
|
1158 |
+
# if height >= width and height - divisible >= min_size:
|
1159 |
+
# resos.add((width, height - divisible))
|
1160 |
+
# resos.add((height - divisible, width))
|
1161 |
+
|
1162 |
+
size += divisible
|
1163 |
+
|
1164 |
+
resos = list(resos)
|
1165 |
+
resos.sort()
|
1166 |
+
return resos
|
1167 |
+
|
1168 |
+
|
1169 |
+
if __name__ == '__main__':
|
1170 |
+
resos = make_bucket_resolutions((512, 768))
|
1171 |
+
print(len(resos))
|
1172 |
+
print(resos)
|
1173 |
+
aspect_ratios = [w / h for w, h in resos]
|
1174 |
+
print(aspect_ratios)
|
1175 |
+
|
1176 |
+
ars = set()
|
1177 |
+
for ar in aspect_ratios:
|
1178 |
+
if ar in ars:
|
1179 |
+
print("error! duplicate ar:", ar)
|
1180 |
+
ars.add(ar)
|
build/lib/library/train_util.py
ADDED
@@ -0,0 +1,1796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# common functions for training
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
+
import glob
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import hashlib
|
15 |
+
import subprocess
|
16 |
+
from io import BytesIO
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
import torch
|
20 |
+
from torchvision import transforms
|
21 |
+
from transformers import CLIPTokenizer
|
22 |
+
import diffusers
|
23 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
24 |
+
import albumentations as albu
|
25 |
+
import numpy as np
|
26 |
+
from PIL import Image
|
27 |
+
import cv2
|
28 |
+
from einops import rearrange
|
29 |
+
from torch import einsum
|
30 |
+
import safetensors.torch
|
31 |
+
|
32 |
+
import library.model_util as model_util
|
33 |
+
|
34 |
+
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
35 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
36 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
37 |
+
|
38 |
+
# checkpointファイル名
|
39 |
+
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
40 |
+
EPOCH_FILE_NAME = "{}-{:06d}"
|
41 |
+
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
|
42 |
+
LAST_STATE_NAME = "{}-state"
|
43 |
+
DEFAULT_EPOCH_NAME = "epoch"
|
44 |
+
DEFAULT_LAST_OUTPUT_NAME = "last"
|
45 |
+
|
46 |
+
# region dataset
|
47 |
+
|
48 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
49 |
+
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
50 |
+
|
51 |
+
|
52 |
+
class ImageInfo():
|
53 |
+
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
54 |
+
self.image_key: str = image_key
|
55 |
+
self.num_repeats: int = num_repeats
|
56 |
+
self.caption: str = caption
|
57 |
+
self.is_reg: bool = is_reg
|
58 |
+
self.absolute_path: str = absolute_path
|
59 |
+
self.image_size: Tuple[int, int] = None
|
60 |
+
self.resized_size: Tuple[int, int] = None
|
61 |
+
self.bucket_reso: Tuple[int, int] = None
|
62 |
+
self.latents: torch.Tensor = None
|
63 |
+
self.latents_flipped: torch.Tensor = None
|
64 |
+
self.latents_npz: str = None
|
65 |
+
self.latents_npz_flipped: str = None
|
66 |
+
|
67 |
+
|
68 |
+
class BucketManager():
|
69 |
+
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
70 |
+
self.no_upscale = no_upscale
|
71 |
+
if max_reso is None:
|
72 |
+
self.max_reso = None
|
73 |
+
self.max_area = None
|
74 |
+
else:
|
75 |
+
self.max_reso = max_reso
|
76 |
+
self.max_area = max_reso[0] * max_reso[1]
|
77 |
+
self.min_size = min_size
|
78 |
+
self.max_size = max_size
|
79 |
+
self.reso_steps = reso_steps
|
80 |
+
|
81 |
+
self.resos = []
|
82 |
+
self.reso_to_id = {}
|
83 |
+
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
84 |
+
|
85 |
+
def add_image(self, reso, image):
|
86 |
+
bucket_id = self.reso_to_id[reso]
|
87 |
+
self.buckets[bucket_id].append(image)
|
88 |
+
|
89 |
+
def shuffle(self):
|
90 |
+
for bucket in self.buckets:
|
91 |
+
random.shuffle(bucket)
|
92 |
+
|
93 |
+
def sort(self):
|
94 |
+
# 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す
|
95 |
+
sorted_resos = self.resos.copy()
|
96 |
+
sorted_resos.sort()
|
97 |
+
|
98 |
+
sorted_buckets = []
|
99 |
+
sorted_reso_to_id = {}
|
100 |
+
for i, reso in enumerate(sorted_resos):
|
101 |
+
bucket_id = self.reso_to_id[reso]
|
102 |
+
sorted_buckets.append(self.buckets[bucket_id])
|
103 |
+
sorted_reso_to_id[reso] = i
|
104 |
+
|
105 |
+
self.resos = sorted_resos
|
106 |
+
self.buckets = sorted_buckets
|
107 |
+
self.reso_to_id = sorted_reso_to_id
|
108 |
+
|
109 |
+
def make_buckets(self):
|
110 |
+
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
111 |
+
self.set_predefined_resos(resos)
|
112 |
+
|
113 |
+
def set_predefined_resos(self, resos):
|
114 |
+
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
115 |
+
self.predefined_resos = resos.copy()
|
116 |
+
self.predefined_resos_set = set(resos)
|
117 |
+
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
118 |
+
|
119 |
+
def add_if_new_reso(self, reso):
|
120 |
+
if reso not in self.reso_to_id:
|
121 |
+
bucket_id = len(self.resos)
|
122 |
+
self.reso_to_id[reso] = bucket_id
|
123 |
+
self.resos.append(reso)
|
124 |
+
self.buckets.append([])
|
125 |
+
# print(reso, bucket_id, len(self.buckets))
|
126 |
+
|
127 |
+
def round_to_steps(self, x):
|
128 |
+
x = int(x + .5)
|
129 |
+
return x - x % self.reso_steps
|
130 |
+
|
131 |
+
def select_bucket(self, image_width, image_height):
|
132 |
+
aspect_ratio = image_width / image_height
|
133 |
+
if not self.no_upscale:
|
134 |
+
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
135 |
+
reso = (image_width, image_height)
|
136 |
+
if reso in self.predefined_resos_set:
|
137 |
+
pass
|
138 |
+
else:
|
139 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
140 |
+
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
141 |
+
reso = self.predefined_resos[predefined_bucket_id]
|
142 |
+
|
143 |
+
ar_reso = reso[0] / reso[1]
|
144 |
+
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
145 |
+
scale = reso[1] / image_height
|
146 |
+
else:
|
147 |
+
scale = reso[0] / image_width
|
148 |
+
|
149 |
+
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
150 |
+
# print("use predef", image_width, image_height, reso, resized_size)
|
151 |
+
else:
|
152 |
+
if image_width * image_height > self.max_area:
|
153 |
+
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
154 |
+
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
155 |
+
resized_height = self.max_area / resized_width
|
156 |
+
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
157 |
+
|
158 |
+
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
159 |
+
# 元のbucketingと同じロジック
|
160 |
+
b_width_rounded = self.round_to_steps(resized_width)
|
161 |
+
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
162 |
+
ar_width_rounded = b_width_rounded / b_height_in_wr
|
163 |
+
|
164 |
+
b_height_rounded = self.round_to_steps(resized_height)
|
165 |
+
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
166 |
+
ar_height_rounded = b_width_in_hr / b_height_rounded
|
167 |
+
|
168 |
+
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
169 |
+
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
170 |
+
|
171 |
+
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
172 |
+
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
173 |
+
else:
|
174 |
+
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
175 |
+
# print(resized_size)
|
176 |
+
else:
|
177 |
+
resized_size = (image_width, image_height) # リサイズは不要
|
178 |
+
|
179 |
+
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
180 |
+
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
181 |
+
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
182 |
+
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
183 |
+
|
184 |
+
reso = (bucket_width, bucket_height)
|
185 |
+
|
186 |
+
self.add_if_new_reso(reso)
|
187 |
+
|
188 |
+
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
189 |
+
return reso, resized_size, ar_error
|
190 |
+
|
191 |
+
|
192 |
+
class BucketBatchIndex(NamedTuple):
|
193 |
+
bucket_index: int
|
194 |
+
bucket_batch_size: int
|
195 |
+
batch_index: int
|
196 |
+
|
197 |
+
|
198 |
+
class BaseDataset(torch.utils.data.Dataset):
|
199 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
200 |
+
super().__init__()
|
201 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
202 |
+
self.max_token_length = max_token_length
|
203 |
+
self.shuffle_caption = shuffle_caption
|
204 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
+
# width/height is used when enable_bucket==False
|
206 |
+
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
+
self.face_crop_aug_range = face_crop_aug_range
|
208 |
+
self.flip_aug = flip_aug
|
209 |
+
self.color_aug = color_aug
|
210 |
+
self.debug_dataset = debug_dataset
|
211 |
+
self.random_crop = random_crop
|
212 |
+
self.token_padding_disabled = False
|
213 |
+
self.dataset_dirs_info = {}
|
214 |
+
self.reg_dataset_dirs_info = {}
|
215 |
+
self.tag_frequency = {}
|
216 |
+
|
217 |
+
self.enable_bucket = False
|
218 |
+
self.bucket_manager: BucketManager = None # not initialized
|
219 |
+
self.min_bucket_reso = None
|
220 |
+
self.max_bucket_reso = None
|
221 |
+
self.bucket_reso_steps = None
|
222 |
+
self.bucket_no_upscale = None
|
223 |
+
self.bucket_info = None # for metadata
|
224 |
+
|
225 |
+
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
+
|
227 |
+
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
+
self.dropout_rate: float = 0
|
229 |
+
self.dropout_every_n_epochs: int = None
|
230 |
+
self.tag_dropout_rate: float = 0
|
231 |
+
|
232 |
+
# augmentation
|
233 |
+
flip_p = 0.5 if flip_aug else 0.0
|
234 |
+
if color_aug:
|
235 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
+
self.aug = albu.Compose([
|
237 |
+
albu.OneOf([
|
238 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
+
albu.RandomGamma((95, 105), p=.5),
|
240 |
+
], p=.33),
|
241 |
+
albu.HorizontalFlip(p=flip_p)
|
242 |
+
], p=1.)
|
243 |
+
elif flip_aug:
|
244 |
+
self.aug = albu.Compose([
|
245 |
+
albu.HorizontalFlip(p=flip_p)
|
246 |
+
], p=1.)
|
247 |
+
else:
|
248 |
+
self.aug = None
|
249 |
+
|
250 |
+
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
+
|
252 |
+
self.image_data: Dict[str, ImageInfo] = {}
|
253 |
+
|
254 |
+
self.replacements = {}
|
255 |
+
|
256 |
+
def set_current_epoch(self, epoch):
|
257 |
+
self.current_epoch = epoch
|
258 |
+
|
259 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
+
self.dropout_rate = dropout_rate
|
262 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
+
self.tag_dropout_rate = tag_dropout_rate
|
264 |
+
|
265 |
+
def set_tag_frequency(self, dir_name, captions):
|
266 |
+
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
+
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
+
for caption in captions:
|
269 |
+
for tag in caption.split(","):
|
270 |
+
if tag and not tag.isspace():
|
271 |
+
tag = tag.lower()
|
272 |
+
frequency = frequency_for_dir.get(tag, 0)
|
273 |
+
frequency_for_dir[tag] = frequency + 1
|
274 |
+
|
275 |
+
def disable_token_padding(self):
|
276 |
+
self.token_padding_disabled = True
|
277 |
+
|
278 |
+
def add_replacement(self, str_from, str_to):
|
279 |
+
self.replacements[str_from] = str_to
|
280 |
+
|
281 |
+
def process_caption(self, caption):
|
282 |
+
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
284 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
285 |
+
|
286 |
+
if is_drop_out:
|
287 |
+
caption = ""
|
288 |
+
else:
|
289 |
+
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
290 |
+
def dropout_tags(tokens):
|
291 |
+
if self.tag_dropout_rate <= 0:
|
292 |
+
return tokens
|
293 |
+
l = []
|
294 |
+
for token in tokens:
|
295 |
+
if random.random() >= self.tag_dropout_rate:
|
296 |
+
l.append(token)
|
297 |
+
return l
|
298 |
+
|
299 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
300 |
+
if self.shuffle_keep_tokens is None:
|
301 |
+
if self.shuffle_caption:
|
302 |
+
random.shuffle(tokens)
|
303 |
+
|
304 |
+
tokens = dropout_tags(tokens)
|
305 |
+
else:
|
306 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
+
|
310 |
+
if self.shuffle_caption:
|
311 |
+
random.shuffle(tokens)
|
312 |
+
|
313 |
+
tokens = dropout_tags(tokens)
|
314 |
+
|
315 |
+
tokens = keep_tokens + tokens
|
316 |
+
caption = ", ".join(tokens)
|
317 |
+
|
318 |
+
# textual inversion対応
|
319 |
+
for str_from, str_to in self.replacements.items():
|
320 |
+
if str_from == "":
|
321 |
+
# replace all
|
322 |
+
if type(str_to) == list:
|
323 |
+
caption = random.choice(str_to)
|
324 |
+
else:
|
325 |
+
caption = str_to
|
326 |
+
else:
|
327 |
+
caption = caption.replace(str_from, str_to)
|
328 |
+
|
329 |
+
return caption
|
330 |
+
|
331 |
+
def get_input_ids(self, caption):
|
332 |
+
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
333 |
+
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
334 |
+
|
335 |
+
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
336 |
+
input_ids = input_ids.squeeze(0)
|
337 |
+
iids_list = []
|
338 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
339 |
+
# v1
|
340 |
+
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
341 |
+
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
342 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
343 |
+
ids_chunk = (input_ids[0].unsqueeze(0),
|
344 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
345 |
+
input_ids[-1].unsqueeze(0))
|
346 |
+
ids_chunk = torch.cat(ids_chunk)
|
347 |
+
iids_list.append(ids_chunk)
|
348 |
+
else:
|
349 |
+
# v2
|
350 |
+
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
351 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
352 |
+
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
353 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
354 |
+
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
355 |
+
ids_chunk = torch.cat(ids_chunk)
|
356 |
+
|
357 |
+
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
358 |
+
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
359 |
+
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
360 |
+
ids_chunk[-1] = self.tokenizer.eos_token_id
|
361 |
+
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
362 |
+
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
363 |
+
ids_chunk[1] = self.tokenizer.eos_token_id
|
364 |
+
|
365 |
+
iids_list.append(ids_chunk)
|
366 |
+
|
367 |
+
input_ids = torch.stack(iids_list) # 3,77
|
368 |
+
return input_ids
|
369 |
+
|
370 |
+
def register_image(self, info: ImageInfo):
|
371 |
+
self.image_data[info.image_key] = info
|
372 |
+
|
373 |
+
def make_buckets(self):
|
374 |
+
'''
|
375 |
+
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
376 |
+
min_size and max_size are ignored when enable_bucket is False
|
377 |
+
'''
|
378 |
+
print("loading image sizes.")
|
379 |
+
for info in tqdm(self.image_data.values()):
|
380 |
+
if info.image_size is None:
|
381 |
+
info.image_size = self.get_image_size(info.absolute_path)
|
382 |
+
|
383 |
+
if self.enable_bucket:
|
384 |
+
print("make buckets")
|
385 |
+
else:
|
386 |
+
print("prepare dataset")
|
387 |
+
|
388 |
+
# bucketを作成し、画像をbucketに振り分ける
|
389 |
+
if self.enable_bucket:
|
390 |
+
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
391 |
+
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
392 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
393 |
+
if not self.bucket_no_upscale:
|
394 |
+
self.bucket_manager.make_buckets()
|
395 |
+
else:
|
396 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
397 |
+
|
398 |
+
img_ar_errors = []
|
399 |
+
for image_info in self.image_data.values():
|
400 |
+
image_width, image_height = image_info.image_size
|
401 |
+
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
402 |
+
|
403 |
+
# print(image_info.image_key, image_info.bucket_reso)
|
404 |
+
img_ar_errors.append(abs(ar_error))
|
405 |
+
|
406 |
+
self.bucket_manager.sort()
|
407 |
+
else:
|
408 |
+
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
409 |
+
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
410 |
+
for image_info in self.image_data.values():
|
411 |
+
image_width, image_height = image_info.image_size
|
412 |
+
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
413 |
+
|
414 |
+
for image_info in self.image_data.values():
|
415 |
+
for _ in range(image_info.num_repeats):
|
416 |
+
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
417 |
+
|
418 |
+
# bucket情報を表示、格納する
|
419 |
+
if self.enable_bucket:
|
420 |
+
self.bucket_info = {"buckets": {}}
|
421 |
+
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
422 |
+
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
423 |
+
count = len(bucket)
|
424 |
+
if count > 0:
|
425 |
+
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
426 |
+
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
427 |
+
|
428 |
+
img_ar_errors = np.array(img_ar_errors)
|
429 |
+
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
430 |
+
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
431 |
+
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
432 |
+
|
433 |
+
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
434 |
+
self.buckets_indices: List(BucketBatchIndex) = []
|
435 |
+
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
436 |
+
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
437 |
+
for batch_index in range(batch_count):
|
438 |
+
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
439 |
+
|
440 |
+
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
441 |
+
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
442 |
+
#
|
443 |
+
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
444 |
+
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
445 |
+
# # そのためバッチサイズを画像種類までに制限する
|
446 |
+
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
447 |
+
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
448 |
+
# num_of_image_types = len(set(bucket))
|
449 |
+
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
450 |
+
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
451 |
+
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
452 |
+
# for batch_index in range(batch_count):
|
453 |
+
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
454 |
+
# ↑ここまで
|
455 |
+
|
456 |
+
self.shuffle_buckets()
|
457 |
+
self._length = len(self.buckets_indices)
|
458 |
+
|
459 |
+
def shuffle_buckets(self):
|
460 |
+
random.shuffle(self.buckets_indices)
|
461 |
+
self.bucket_manager.shuffle()
|
462 |
+
|
463 |
+
def load_image(self, image_path):
|
464 |
+
image = Image.open(image_path)
|
465 |
+
if not image.mode == "RGB":
|
466 |
+
image = image.convert("RGB")
|
467 |
+
img = np.array(image, np.uint8)
|
468 |
+
return img
|
469 |
+
|
470 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
+
image_height, image_width = image.shape[0:2]
|
472 |
+
|
473 |
+
if image_width != resized_size[0] or image_height != resized_size[1]:
|
474 |
+
# リサイズする
|
475 |
+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
476 |
+
|
477 |
+
image_height, image_width = image.shape[0:2]
|
478 |
+
if image_width > reso[0]:
|
479 |
+
trim_size = image_width - reso[0]
|
480 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
481 |
+
# print("w", trim_size, p)
|
482 |
+
image = image[:, p:p + reso[0]]
|
483 |
+
if image_height > reso[1]:
|
484 |
+
trim_size = image_height - reso[1]
|
485 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
486 |
+
# print("h", trim_size, p)
|
487 |
+
image = image[p:p + reso[1]]
|
488 |
+
|
489 |
+
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
+
return image
|
491 |
+
|
492 |
+
def cache_latents(self, vae):
|
493 |
+
# TODO ここを高速化したい
|
494 |
+
print("caching latents.")
|
495 |
+
for info in tqdm(self.image_data.values()):
|
496 |
+
if info.latents_npz is not None:
|
497 |
+
info.latents = self.load_latents_from_npz(info, False)
|
498 |
+
info.latents = torch.FloatTensor(info.latents)
|
499 |
+
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
500 |
+
if info.latents_flipped is not None:
|
501 |
+
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
502 |
+
continue
|
503 |
+
|
504 |
+
image = self.load_image(info.absolute_path)
|
505 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
+
|
507 |
+
img_tensor = self.image_transforms(image)
|
508 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
+
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
+
|
511 |
+
if self.flip_aug:
|
512 |
+
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
+
img_tensor = self.image_transforms(image)
|
514 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
515 |
+
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
516 |
+
|
517 |
+
def get_image_size(self, image_path):
|
518 |
+
image = Image.open(image_path)
|
519 |
+
return image.size
|
520 |
+
|
521 |
+
def load_image_with_face_info(self, image_path: str):
|
522 |
+
img = self.load_image(image_path)
|
523 |
+
|
524 |
+
face_cx = face_cy = face_w = face_h = 0
|
525 |
+
if self.face_crop_aug_range is not None:
|
526 |
+
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
+
if len(tokens) >= 5:
|
528 |
+
face_cx = int(tokens[-4])
|
529 |
+
face_cy = int(tokens[-3])
|
530 |
+
face_w = int(tokens[-2])
|
531 |
+
face_h = int(tokens[-1])
|
532 |
+
|
533 |
+
return img, face_cx, face_cy, face_w, face_h
|
534 |
+
|
535 |
+
# いい感じに切り出す
|
536 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
+
height, width = image.shape[0:2]
|
538 |
+
if height == self.height and width == self.width:
|
539 |
+
return image
|
540 |
+
|
541 |
+
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
+
face_size = max(face_w, face_h)
|
543 |
+
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
545 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
546 |
+
if min_scale >= max_scale: # range指定がmin==max
|
547 |
+
scale = min_scale
|
548 |
+
else:
|
549 |
+
scale = random.uniform(min_scale, max_scale)
|
550 |
+
|
551 |
+
nh = int(height * scale + .5)
|
552 |
+
nw = int(width * scale + .5)
|
553 |
+
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
554 |
+
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
555 |
+
face_cx = int(face_cx * scale + .5)
|
556 |
+
face_cy = int(face_cy * scale + .5)
|
557 |
+
height, width = nh, nw
|
558 |
+
|
559 |
+
# 顔を中心として448*640とかへ切り出す
|
560 |
+
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
+
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
+
|
563 |
+
if self.random_crop:
|
564 |
+
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
+
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
+
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
+
else:
|
568 |
+
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
570 |
+
if face_size > self.size // 10 and face_size >= 40:
|
571 |
+
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
+
|
573 |
+
p1 = max(0, min(p1, length - target_size))
|
574 |
+
|
575 |
+
if axis == 0:
|
576 |
+
image = image[p1:p1 + target_size, :]
|
577 |
+
else:
|
578 |
+
image = image[:, p1:p1 + target_size]
|
579 |
+
|
580 |
+
return image
|
581 |
+
|
582 |
+
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
583 |
+
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
584 |
+
if npz_file is None:
|
585 |
+
return None
|
586 |
+
return np.load(npz_file)['arr_0']
|
587 |
+
|
588 |
+
def __len__(self):
|
589 |
+
return self._length
|
590 |
+
|
591 |
+
def __getitem__(self, index):
|
592 |
+
if index == 0:
|
593 |
+
self.shuffle_buckets()
|
594 |
+
|
595 |
+
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
+
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
+
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
598 |
+
|
599 |
+
loss_weights = []
|
600 |
+
captions = []
|
601 |
+
input_ids_list = []
|
602 |
+
latents_list = []
|
603 |
+
images = []
|
604 |
+
|
605 |
+
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
+
image_info = self.image_data[image_key]
|
607 |
+
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
+
|
609 |
+
# image/latentsを処理する
|
610 |
+
if image_info.latents is not None:
|
611 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
612 |
+
image = None
|
613 |
+
elif image_info.latents_npz is not None:
|
614 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
615 |
+
latents = torch.FloatTensor(latents)
|
616 |
+
image = None
|
617 |
+
else:
|
618 |
+
# 画像を読み込み、必要ならcropする
|
619 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
+
im_h, im_w = img.shape[0:2]
|
621 |
+
|
622 |
+
if self.enable_bucket:
|
623 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
+
else:
|
625 |
+
if face_cx > 0: # 顔位置情報あり
|
626 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
+
elif im_h > self.height or im_w > self.width:
|
628 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
629 |
+
if im_h > self.height:
|
630 |
+
p = random.randint(0, im_h - self.height)
|
631 |
+
img = img[p:p + self.height]
|
632 |
+
if im_w > self.width:
|
633 |
+
p = random.randint(0, im_w - self.width)
|
634 |
+
img = img[:, p:p + self.width]
|
635 |
+
|
636 |
+
im_h, im_w = img.shape[0:2]
|
637 |
+
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
+
|
639 |
+
# augmentation
|
640 |
+
if self.aug is not None:
|
641 |
+
img = self.aug(image=img)['image']
|
642 |
+
|
643 |
+
latents = None
|
644 |
+
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
645 |
+
|
646 |
+
images.append(image)
|
647 |
+
latents_list.append(latents)
|
648 |
+
|
649 |
+
caption = self.process_caption(image_info.caption)
|
650 |
+
captions.append(caption)
|
651 |
+
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
+
input_ids_list.append(self.get_input_ids(caption))
|
653 |
+
|
654 |
+
example = {}
|
655 |
+
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
656 |
+
|
657 |
+
if self.token_padding_disabled:
|
658 |
+
# padding=True means pad in the batch
|
659 |
+
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
660 |
+
else:
|
661 |
+
# batch processing seems to be good
|
662 |
+
example['input_ids'] = torch.stack(input_ids_list)
|
663 |
+
|
664 |
+
if images[0] is not None:
|
665 |
+
images = torch.stack(images)
|
666 |
+
images = images.to(memory_format=torch.contiguous_format).float()
|
667 |
+
else:
|
668 |
+
images = None
|
669 |
+
example['images'] = images
|
670 |
+
|
671 |
+
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
672 |
+
|
673 |
+
if self.debug_dataset:
|
674 |
+
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
675 |
+
example['captions'] = captions
|
676 |
+
return example
|
677 |
+
|
678 |
+
|
679 |
+
class DreamBoothDataset(BaseDataset):
|
680 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
681 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
682 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
+
|
684 |
+
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
+
|
686 |
+
self.batch_size = batch_size
|
687 |
+
self.size = min(self.width, self.height) # 短いほう
|
688 |
+
self.prior_loss_weight = prior_loss_weight
|
689 |
+
self.latents_cache = None
|
690 |
+
|
691 |
+
self.enable_bucket = enable_bucket
|
692 |
+
if self.enable_bucket:
|
693 |
+
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
694 |
+
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
695 |
+
self.min_bucket_reso = min_bucket_reso
|
696 |
+
self.max_bucket_reso = max_bucket_reso
|
697 |
+
self.bucket_reso_steps = bucket_reso_steps
|
698 |
+
self.bucket_no_upscale = bucket_no_upscale
|
699 |
+
else:
|
700 |
+
self.min_bucket_reso = None
|
701 |
+
self.max_bucket_reso = None
|
702 |
+
self.bucket_reso_steps = None # この情報は使われない
|
703 |
+
self.bucket_no_upscale = False
|
704 |
+
|
705 |
+
def read_caption(img_path):
|
706 |
+
# captionの候補ファイル名を作る
|
707 |
+
base_name = os.path.splitext(img_path)[0]
|
708 |
+
base_name_face_det = base_name
|
709 |
+
tokens = base_name.split("_")
|
710 |
+
if len(tokens) >= 5:
|
711 |
+
base_name_face_det = "_".join(tokens[:-4])
|
712 |
+
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
713 |
+
|
714 |
+
caption = None
|
715 |
+
for cap_path in cap_paths:
|
716 |
+
if os.path.isfile(cap_path):
|
717 |
+
with open(cap_path, "rt", encoding='utf-8') as f:
|
718 |
+
try:
|
719 |
+
lines = f.readlines()
|
720 |
+
except UnicodeDecodeError as e:
|
721 |
+
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
722 |
+
raise e
|
723 |
+
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
724 |
+
caption = lines[0].strip()
|
725 |
+
break
|
726 |
+
return caption
|
727 |
+
|
728 |
+
def load_dreambooth_dir(dir):
|
729 |
+
if not os.path.isdir(dir):
|
730 |
+
# print(f"ignore file: {dir}")
|
731 |
+
return 0, [], []
|
732 |
+
|
733 |
+
tokens = os.path.basename(dir).split('_')
|
734 |
+
try:
|
735 |
+
n_repeats = int(tokens[0])
|
736 |
+
except ValueError as e:
|
737 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
+
return 0, [], []
|
739 |
+
|
740 |
+
caption_by_folder = '_'.join(tokens[1:])
|
741 |
+
img_paths = glob_images(dir, "*")
|
742 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
+
|
744 |
+
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
+
captions = []
|
746 |
+
for img_path in img_paths:
|
747 |
+
cap_for_img = read_caption(img_path)
|
748 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
749 |
+
|
750 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
751 |
+
|
752 |
+
return n_repeats, img_paths, captions
|
753 |
+
|
754 |
+
print("prepare train images.")
|
755 |
+
train_dirs = os.listdir(train_data_dir)
|
756 |
+
num_train_images = 0
|
757 |
+
for dir in train_dirs:
|
758 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
759 |
+
num_train_images += n_repeats * len(img_paths)
|
760 |
+
|
761 |
+
for img_path, caption in zip(img_paths, captions):
|
762 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
763 |
+
self.register_image(info)
|
764 |
+
|
765 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
766 |
+
|
767 |
+
print(f"{num_train_images} train images with repeating.")
|
768 |
+
self.num_train_images = num_train_images
|
769 |
+
|
770 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
771 |
+
num_reg_images = 0
|
772 |
+
if reg_data_dir:
|
773 |
+
print("prepare reg images.")
|
774 |
+
reg_infos: List[ImageInfo] = []
|
775 |
+
|
776 |
+
reg_dirs = os.listdir(reg_data_dir)
|
777 |
+
for dir in reg_dirs:
|
778 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
779 |
+
num_reg_images += n_repeats * len(img_paths)
|
780 |
+
|
781 |
+
for img_path, caption in zip(img_paths, captions):
|
782 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
+
reg_infos.append(info)
|
784 |
+
|
785 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
+
|
787 |
+
print(f"{num_reg_images} reg images.")
|
788 |
+
if num_train_images < num_reg_images:
|
789 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
790 |
+
|
791 |
+
if num_reg_images == 0:
|
792 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
793 |
+
else:
|
794 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
795 |
+
n = 0
|
796 |
+
first_loop = True
|
797 |
+
while n < num_train_images:
|
798 |
+
for info in reg_infos:
|
799 |
+
if first_loop:
|
800 |
+
self.register_image(info)
|
801 |
+
n += info.num_repeats
|
802 |
+
else:
|
803 |
+
info.num_repeats += 1
|
804 |
+
n += 1
|
805 |
+
if n >= num_train_images:
|
806 |
+
break
|
807 |
+
first_loop = False
|
808 |
+
|
809 |
+
self.num_reg_images = num_reg_images
|
810 |
+
|
811 |
+
|
812 |
+
class FineTuningDataset(BaseDataset):
|
813 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
814 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
815 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
816 |
+
|
817 |
+
# メタデータを読み込む
|
818 |
+
if os.path.exists(json_file_name):
|
819 |
+
print(f"loading existing metadata: {json_file_name}")
|
820 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
+
metadata = json.load(f)
|
822 |
+
else:
|
823 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
+
|
825 |
+
self.metadata = metadata
|
826 |
+
self.train_data_dir = train_data_dir
|
827 |
+
self.batch_size = batch_size
|
828 |
+
|
829 |
+
tags_list = []
|
830 |
+
for image_key, img_md in metadata.items():
|
831 |
+
# path情報を作る
|
832 |
+
if os.path.exists(image_key):
|
833 |
+
abs_path = image_key
|
834 |
+
else:
|
835 |
+
# わりといい加減だがいい方法が思いつかん
|
836 |
+
abs_path = glob_images(train_data_dir, image_key)
|
837 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
+
abs_path = abs_path[0]
|
839 |
+
|
840 |
+
caption = img_md.get('caption')
|
841 |
+
tags = img_md.get('tags')
|
842 |
+
if caption is None:
|
843 |
+
caption = tags
|
844 |
+
elif tags is not None and len(tags) > 0:
|
845 |
+
caption = caption + ', ' + tags
|
846 |
+
tags_list.append(tags)
|
847 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
+
|
849 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
+
image_info.image_size = img_md.get('train_resolution')
|
851 |
+
|
852 |
+
if not self.color_aug and not self.random_crop:
|
853 |
+
# if npz exists, use them
|
854 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
+
|
856 |
+
self.register_image(image_info)
|
857 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
+
self.num_reg_images = 0
|
859 |
+
|
860 |
+
# TODO do not record tag freq when no tag
|
861 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
862 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
863 |
+
|
864 |
+
# check existence of all npz files
|
865 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
866 |
+
if use_npz_latents:
|
867 |
+
npz_any = False
|
868 |
+
npz_all = True
|
869 |
+
for image_info in self.image_data.values():
|
870 |
+
has_npz = image_info.latents_npz is not None
|
871 |
+
npz_any = npz_any or has_npz
|
872 |
+
|
873 |
+
if self.flip_aug:
|
874 |
+
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
875 |
+
npz_all = npz_all and has_npz
|
876 |
+
|
877 |
+
if npz_any and not npz_all:
|
878 |
+
break
|
879 |
+
|
880 |
+
if not npz_any:
|
881 |
+
use_npz_latents = False
|
882 |
+
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
883 |
+
elif not npz_all:
|
884 |
+
use_npz_latents = False
|
885 |
+
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
+
if self.flip_aug:
|
887 |
+
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
+
# else:
|
889 |
+
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
890 |
+
|
891 |
+
# check min/max bucket size
|
892 |
+
sizes = set()
|
893 |
+
resos = set()
|
894 |
+
for image_info in self.image_data.values():
|
895 |
+
if image_info.image_size is None:
|
896 |
+
sizes = None # not calculated
|
897 |
+
break
|
898 |
+
sizes.add(image_info.image_size[0])
|
899 |
+
sizes.add(image_info.image_size[1])
|
900 |
+
resos.add(tuple(image_info.image_size))
|
901 |
+
|
902 |
+
if sizes is None:
|
903 |
+
if use_npz_latents:
|
904 |
+
use_npz_latents = False
|
905 |
+
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
|
906 |
+
|
907 |
+
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
908 |
+
|
909 |
+
self.enable_bucket = enable_bucket
|
910 |
+
if self.enable_bucket:
|
911 |
+
self.min_bucket_reso = min_bucket_reso
|
912 |
+
self.max_bucket_reso = max_bucket_reso
|
913 |
+
self.bucket_reso_steps = bucket_reso_steps
|
914 |
+
self.bucket_no_upscale = bucket_no_upscale
|
915 |
+
else:
|
916 |
+
if not enable_bucket:
|
917 |
+
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
918 |
+
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
919 |
+
self.enable_bucket = True
|
920 |
+
|
921 |
+
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
922 |
+
|
923 |
+
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
924 |
+
self.bucket_manager = BucketManager(False, None, None, None, None)
|
925 |
+
self.bucket_manager.set_predefined_resos(resos)
|
926 |
+
|
927 |
+
# npz情報をきれいにしておく
|
928 |
+
if not use_npz_latents:
|
929 |
+
for image_info in self.image_data.values():
|
930 |
+
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
+
|
932 |
+
def image_key_to_npz_file(self, image_key):
|
933 |
+
base_name = os.path.splitext(image_key)[0]
|
934 |
+
npz_file_norm = base_name + '.npz'
|
935 |
+
|
936 |
+
if os.path.exists(npz_file_norm):
|
937 |
+
# image_key is full path
|
938 |
+
npz_file_flip = base_name + '_flip.npz'
|
939 |
+
if not os.path.exists(npz_file_flip):
|
940 |
+
npz_file_flip = None
|
941 |
+
return npz_file_norm, npz_file_flip
|
942 |
+
|
943 |
+
# image_key is relative path
|
944 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
945 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
946 |
+
|
947 |
+
if not os.path.exists(npz_file_norm):
|
948 |
+
npz_file_norm = None
|
949 |
+
npz_file_flip = None
|
950 |
+
elif not os.path.exists(npz_file_flip):
|
951 |
+
npz_file_flip = None
|
952 |
+
|
953 |
+
return npz_file_norm, npz_file_flip
|
954 |
+
|
955 |
+
|
956 |
+
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
+
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
+
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
+
|
960 |
+
train_dataset.set_current_epoch(1)
|
961 |
+
k = 0
|
962 |
+
for i, example in enumerate(train_dataset):
|
963 |
+
if example['latents'] is not None:
|
964 |
+
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
+
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
966 |
+
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
967 |
+
if show_input_ids:
|
968 |
+
print(f"input ids: {iid}")
|
969 |
+
if example['images'] is not None:
|
970 |
+
im = example['images'][j]
|
971 |
+
print(f"image size: {im.size()}")
|
972 |
+
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
973 |
+
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
974 |
+
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
975 |
+
if os.name == 'nt': # only windows
|
976 |
+
cv2.imshow("img", im)
|
977 |
+
k = cv2.waitKey()
|
978 |
+
cv2.destroyAllWindows()
|
979 |
+
if k == 27:
|
980 |
+
break
|
981 |
+
if k == 27 or (example['images'] is None and i >= 8):
|
982 |
+
break
|
983 |
+
|
984 |
+
|
985 |
+
def glob_images(directory, base="*"):
|
986 |
+
img_paths = []
|
987 |
+
for ext in IMAGE_EXTENSIONS:
|
988 |
+
if base == '*':
|
989 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
990 |
+
else:
|
991 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
992 |
+
# img_paths = list(set(img_paths)) # 重複を排除
|
993 |
+
# img_paths.sort()
|
994 |
+
return img_paths
|
995 |
+
|
996 |
+
|
997 |
+
def glob_images_pathlib(dir_path, recursive):
|
998 |
+
image_paths = []
|
999 |
+
if recursive:
|
1000 |
+
for ext in IMAGE_EXTENSIONS:
|
1001 |
+
image_paths += list(dir_path.rglob('*' + ext))
|
1002 |
+
else:
|
1003 |
+
for ext in IMAGE_EXTENSIONS:
|
1004 |
+
image_paths += list(dir_path.glob('*' + ext))
|
1005 |
+
# image_paths = list(set(image_paths)) # 重複を排除
|
1006 |
+
# image_paths.sort()
|
1007 |
+
return image_paths
|
1008 |
+
|
1009 |
+
# endregion
|
1010 |
+
|
1011 |
+
|
1012 |
+
# region モジュール入れ替え部
|
1013 |
+
"""
|
1014 |
+
高速化のためのモジュール入れ替え
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
# FlashAttentionを使うCrossAttention
|
1018 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
1019 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
1020 |
+
|
1021 |
+
# constants
|
1022 |
+
|
1023 |
+
EPSILON = 1e-6
|
1024 |
+
|
1025 |
+
# helper functions
|
1026 |
+
|
1027 |
+
|
1028 |
+
def exists(val):
|
1029 |
+
return val is not None
|
1030 |
+
|
1031 |
+
|
1032 |
+
def default(val, d):
|
1033 |
+
return val if exists(val) else d
|
1034 |
+
|
1035 |
+
|
1036 |
+
def model_hash(filename):
|
1037 |
+
"""Old model hash used by stable-diffusion-webui"""
|
1038 |
+
try:
|
1039 |
+
with open(filename, "rb") as file:
|
1040 |
+
m = hashlib.sha256()
|
1041 |
+
|
1042 |
+
file.seek(0x100000)
|
1043 |
+
m.update(file.read(0x10000))
|
1044 |
+
return m.hexdigest()[0:8]
|
1045 |
+
except FileNotFoundError:
|
1046 |
+
return 'NOFILE'
|
1047 |
+
|
1048 |
+
|
1049 |
+
def calculate_sha256(filename):
|
1050 |
+
"""New model hash used by stable-diffusion-webui"""
|
1051 |
+
hash_sha256 = hashlib.sha256()
|
1052 |
+
blksize = 1024 * 1024
|
1053 |
+
|
1054 |
+
with open(filename, "rb") as f:
|
1055 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
1056 |
+
hash_sha256.update(chunk)
|
1057 |
+
|
1058 |
+
return hash_sha256.hexdigest()
|
1059 |
+
|
1060 |
+
|
1061 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
1062 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
1063 |
+
save time on indexing the model later."""
|
1064 |
+
|
1065 |
+
# Because writing user metadata to the file can change the result of
|
1066 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
1067 |
+
# calculating the hash, as they are meant to be immutable
|
1068 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
1069 |
+
|
1070 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
1071 |
+
b = BytesIO(bytes)
|
1072 |
+
|
1073 |
+
model_hash = addnet_hash_safetensors(b)
|
1074 |
+
legacy_hash = addnet_hash_legacy(b)
|
1075 |
+
return model_hash, legacy_hash
|
1076 |
+
|
1077 |
+
|
1078 |
+
def addnet_hash_legacy(b):
|
1079 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1080 |
+
m = hashlib.sha256()
|
1081 |
+
|
1082 |
+
b.seek(0x100000)
|
1083 |
+
m.update(b.read(0x10000))
|
1084 |
+
return m.hexdigest()[0:8]
|
1085 |
+
|
1086 |
+
|
1087 |
+
def addnet_hash_safetensors(b):
|
1088 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1089 |
+
hash_sha256 = hashlib.sha256()
|
1090 |
+
blksize = 1024 * 1024
|
1091 |
+
|
1092 |
+
b.seek(0)
|
1093 |
+
header = b.read(8)
|
1094 |
+
n = int.from_bytes(header, "little")
|
1095 |
+
|
1096 |
+
offset = n + 8
|
1097 |
+
b.seek(offset)
|
1098 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
1099 |
+
hash_sha256.update(chunk)
|
1100 |
+
|
1101 |
+
return hash_sha256.hexdigest()
|
1102 |
+
|
1103 |
+
|
1104 |
+
def get_git_revision_hash() -> str:
|
1105 |
+
try:
|
1106 |
+
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
|
1107 |
+
except:
|
1108 |
+
return "(unknown)"
|
1109 |
+
|
1110 |
+
|
1111 |
+
# flash attention forwards and backwards
|
1112 |
+
|
1113 |
+
# https://arxiv.org/abs/2205.14135
|
1114 |
+
|
1115 |
+
|
1116 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
1117 |
+
@ staticmethod
|
1118 |
+
@ torch.no_grad()
|
1119 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
1120 |
+
""" Algorithm 2 in the paper """
|
1121 |
+
|
1122 |
+
device = q.device
|
1123 |
+
dtype = q.dtype
|
1124 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1125 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1126 |
+
|
1127 |
+
o = torch.zeros_like(q)
|
1128 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
1129 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
1130 |
+
|
1131 |
+
scale = (q.shape[-1] ** -0.5)
|
1132 |
+
|
1133 |
+
if not exists(mask):
|
1134 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
1135 |
+
else:
|
1136 |
+
mask = rearrange(mask, 'b n -> b 1 1 n')
|
1137 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
1138 |
+
|
1139 |
+
row_splits = zip(
|
1140 |
+
q.split(q_bucket_size, dim=-2),
|
1141 |
+
o.split(q_bucket_size, dim=-2),
|
1142 |
+
mask,
|
1143 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
1144 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
1148 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1149 |
+
|
1150 |
+
col_splits = zip(
|
1151 |
+
k.split(k_bucket_size, dim=-2),
|
1152 |
+
v.split(k_bucket_size, dim=-2),
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
1156 |
+
k_start_index = k_ind * k_bucket_size
|
1157 |
+
|
1158 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1159 |
+
|
1160 |
+
if exists(row_mask):
|
1161 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
1162 |
+
|
1163 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1164 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1165 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1166 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1167 |
+
|
1168 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
1169 |
+
attn_weights -= block_row_maxes
|
1170 |
+
exp_weights = torch.exp(attn_weights)
|
1171 |
+
|
1172 |
+
if exists(row_mask):
|
1173 |
+
exp_weights.masked_fill_(~row_mask, 0.)
|
1174 |
+
|
1175 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
1176 |
+
|
1177 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
1178 |
+
|
1179 |
+
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
|
1180 |
+
|
1181 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
1182 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
1183 |
+
|
1184 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
1185 |
+
|
1186 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
1187 |
+
|
1188 |
+
row_maxes.copy_(new_row_maxes)
|
1189 |
+
row_sums.copy_(new_row_sums)
|
1190 |
+
|
1191 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
1192 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
1193 |
+
|
1194 |
+
return o
|
1195 |
+
|
1196 |
+
@ staticmethod
|
1197 |
+
@ torch.no_grad()
|
1198 |
+
def backward(ctx, do):
|
1199 |
+
""" Algorithm 4 in the paper """
|
1200 |
+
|
1201 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
1202 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
1203 |
+
|
1204 |
+
device = q.device
|
1205 |
+
|
1206 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1207 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1208 |
+
|
1209 |
+
dq = torch.zeros_like(q)
|
1210 |
+
dk = torch.zeros_like(k)
|
1211 |
+
dv = torch.zeros_like(v)
|
1212 |
+
|
1213 |
+
row_splits = zip(
|
1214 |
+
q.split(q_bucket_size, dim=-2),
|
1215 |
+
o.split(q_bucket_size, dim=-2),
|
1216 |
+
do.split(q_bucket_size, dim=-2),
|
1217 |
+
mask,
|
1218 |
+
l.split(q_bucket_size, dim=-2),
|
1219 |
+
m.split(q_bucket_size, dim=-2),
|
1220 |
+
dq.split(q_bucket_size, dim=-2)
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
1224 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1225 |
+
|
1226 |
+
col_splits = zip(
|
1227 |
+
k.split(k_bucket_size, dim=-2),
|
1228 |
+
v.split(k_bucket_size, dim=-2),
|
1229 |
+
dk.split(k_bucket_size, dim=-2),
|
1230 |
+
dv.split(k_bucket_size, dim=-2),
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
1234 |
+
k_start_index = k_ind * k_bucket_size
|
1235 |
+
|
1236 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1237 |
+
|
1238 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1239 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1240 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1241 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1242 |
+
|
1243 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
1244 |
+
|
1245 |
+
if exists(row_mask):
|
1246 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.)
|
1247 |
+
|
1248 |
+
p = exp_attn_weights / lc
|
1249 |
+
|
1250 |
+
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
|
1251 |
+
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
|
1252 |
+
|
1253 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
1254 |
+
ds = p * scale * (dp - D)
|
1255 |
+
|
1256 |
+
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
|
1257 |
+
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
|
1258 |
+
|
1259 |
+
dqc.add_(dq_chunk)
|
1260 |
+
dkc.add_(dk_chunk)
|
1261 |
+
dvc.add_(dv_chunk)
|
1262 |
+
|
1263 |
+
return dq, dk, dv, None, None, None, None
|
1264 |
+
|
1265 |
+
|
1266 |
+
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
1267 |
+
if mem_eff_attn:
|
1268 |
+
replace_unet_cross_attn_to_memory_efficient()
|
1269 |
+
elif xformers:
|
1270 |
+
replace_unet_cross_attn_to_xformers()
|
1271 |
+
|
1272 |
+
|
1273 |
+
def replace_unet_cross_attn_to_memory_efficient():
|
1274 |
+
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
1275 |
+
flash_func = FlashAttentionFunction
|
1276 |
+
|
1277 |
+
def forward_flash_attn(self, x, context=None, mask=None):
|
1278 |
+
q_bucket_size = 512
|
1279 |
+
k_bucket_size = 1024
|
1280 |
+
|
1281 |
+
h = self.heads
|
1282 |
+
q = self.to_q(x)
|
1283 |
+
|
1284 |
+
context = context if context is not None else x
|
1285 |
+
context = context.to(x.dtype)
|
1286 |
+
|
1287 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1288 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1289 |
+
context_k = context_k.to(x.dtype)
|
1290 |
+
context_v = context_v.to(x.dtype)
|
1291 |
+
else:
|
1292 |
+
context_k = context
|
1293 |
+
context_v = context
|
1294 |
+
|
1295 |
+
k = self.to_k(context_k)
|
1296 |
+
v = self.to_v(context_v)
|
1297 |
+
del context, x
|
1298 |
+
|
1299 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
1300 |
+
|
1301 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
1302 |
+
|
1303 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
1304 |
+
|
1305 |
+
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
1306 |
+
out = self.to_out[0](out)
|
1307 |
+
out = self.to_out[1](out)
|
1308 |
+
return out
|
1309 |
+
|
1310 |
+
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
1311 |
+
|
1312 |
+
|
1313 |
+
def replace_unet_cross_attn_to_xformers():
|
1314 |
+
print("Replace CrossAttention.forward to use xformers")
|
1315 |
+
try:
|
1316 |
+
import xformers.ops
|
1317 |
+
except ImportError:
|
1318 |
+
raise ImportError("No xformers / xformersがインストールされていないようです")
|
1319 |
+
|
1320 |
+
def forward_xformers(self, x, context=None, mask=None):
|
1321 |
+
h = self.heads
|
1322 |
+
q_in = self.to_q(x)
|
1323 |
+
|
1324 |
+
context = default(context, x)
|
1325 |
+
context = context.to(x.dtype)
|
1326 |
+
|
1327 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1328 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1329 |
+
context_k = context_k.to(x.dtype)
|
1330 |
+
context_v = context_v.to(x.dtype)
|
1331 |
+
else:
|
1332 |
+
context_k = context
|
1333 |
+
context_v = context
|
1334 |
+
|
1335 |
+
k_in = self.to_k(context_k)
|
1336 |
+
v_in = self.to_v(context_v)
|
1337 |
+
|
1338 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
1339 |
+
del q_in, k_in, v_in
|
1340 |
+
|
1341 |
+
q = q.contiguous()
|
1342 |
+
k = k.contiguous()
|
1343 |
+
v = v.contiguous()
|
1344 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
1345 |
+
|
1346 |
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
1347 |
+
|
1348 |
+
# diffusers 0.7.0~
|
1349 |
+
out = self.to_out[0](out)
|
1350 |
+
out = self.to_out[1](out)
|
1351 |
+
return out
|
1352 |
+
|
1353 |
+
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
1354 |
+
# endregion
|
1355 |
+
|
1356 |
+
|
1357 |
+
# region arguments
|
1358 |
+
|
1359 |
+
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
1360 |
+
# for pretrained models
|
1361 |
+
parser.add_argument("--v2", action='store_true',
|
1362 |
+
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
1363 |
+
parser.add_argument("--v_parameterization", action='store_true',
|
1364 |
+
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
+
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1367 |
+
|
1368 |
+
|
1369 |
+
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
1370 |
+
parser.add_argument("--output_dir", type=str, default=None,
|
1371 |
+
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
1372 |
+
parser.add_argument("--output_name", type=str, default=None,
|
1373 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
1374 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
1375 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
1376 |
+
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
1377 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
1378 |
+
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
1379 |
+
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
1380 |
+
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
1381 |
+
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
1382 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
1383 |
+
parser.add_argument("--save_state", action="store_true",
|
1384 |
+
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
1385 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1386 |
+
|
1387 |
+
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
+
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
+
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
+
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
+
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
+
parser.add_argument("--xformers", action="store_true",
|
1397 |
+
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
1398 |
+
parser.add_argument("--vae", type=str, default=None,
|
1399 |
+
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
+
|
1401 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
+
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
1405 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
1406 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
1407 |
+
parser.add_argument("--persistent_data_loader_workers", action="store_true",
|
1408 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
|
1409 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1410 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
1411 |
+
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
1412 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
1413 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
|
1414 |
+
parser.add_argument("--mixed_precision", type=str, default="no",
|
1415 |
+
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
1416 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1417 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
1418 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
1419 |
+
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
+
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
+
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
+
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
+
parser.add_argument("--lowram", action="store_true",
|
1429 |
+
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
+
|
1431 |
+
if support_dreambooth:
|
1432 |
+
# DreamBooth training
|
1433 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
1434 |
+
help="loss weight for regularization images / 正則化画像のlossの重み")
|
1435 |
+
|
1436 |
+
|
1437 |
+
def verify_training_args(args: argparse.Namespace):
|
1438 |
+
if args.v_parameterization and not args.v2:
|
1439 |
+
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
1440 |
+
if args.v2 and args.clip_skip is not None:
|
1441 |
+
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
1442 |
+
|
1443 |
+
|
1444 |
+
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
1445 |
+
# dataset common
|
1446 |
+
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
1447 |
+
parser.add_argument("--shuffle_caption", action="store_true",
|
1448 |
+
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
1449 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
+
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1453 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1454 |
+
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
+
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
+
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
1457 |
+
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
|
1458 |
+
parser.add_argument("--random_crop", action="store_true",
|
1459 |
+
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
|
1460 |
+
parser.add_argument("--debug_dataset", action="store_true",
|
1461 |
+
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
1462 |
+
parser.add_argument("--resolution", type=str, default=None,
|
1463 |
+
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
|
1464 |
+
parser.add_argument("--cache_latents", action="store_true",
|
1465 |
+
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
1466 |
+
parser.add_argument("--enable_bucket", action="store_true",
|
1467 |
+
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
1468 |
+
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
1469 |
+
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
1470 |
+
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
1471 |
+
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
1472 |
+
parser.add_argument("--bucket_no_upscale", action="store_true",
|
1473 |
+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
1474 |
+
|
1475 |
+
if support_caption_dropout:
|
1476 |
+
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
+
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
+
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1481 |
+
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
+
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
+
|
1485 |
+
if support_dreambooth:
|
1486 |
+
# DreamBooth dataset
|
1487 |
+
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
1488 |
+
|
1489 |
+
if support_caption:
|
1490 |
+
# caption dataset
|
1491 |
+
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
|
1492 |
+
parser.add_argument("--dataset_repeats", type=int, default=1,
|
1493 |
+
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
|
1494 |
+
|
1495 |
+
|
1496 |
+
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
1497 |
+
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
|
1498 |
+
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
|
1499 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
1500 |
+
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
1501 |
+
|
1502 |
+
# endregion
|
1503 |
+
|
1504 |
+
# region utils
|
1505 |
+
|
1506 |
+
|
1507 |
+
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
+
# backward compatibility
|
1509 |
+
if args.caption_extention is not None:
|
1510 |
+
args.caption_extension = args.caption_extention
|
1511 |
+
args.caption_extention = None
|
1512 |
+
|
1513 |
+
if args.cache_latents:
|
1514 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
+
|
1517 |
+
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
+
if args.resolution is not None:
|
1519 |
+
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
1520 |
+
if len(args.resolution) == 1:
|
1521 |
+
args.resolution = (args.resolution[0], args.resolution[0])
|
1522 |
+
assert len(args.resolution) == 2, \
|
1523 |
+
f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
1524 |
+
|
1525 |
+
if args.face_crop_aug_range is not None:
|
1526 |
+
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
1527 |
+
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
|
1528 |
+
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
|
1529 |
+
else:
|
1530 |
+
args.face_crop_aug_range = None
|
1531 |
+
|
1532 |
+
if support_metadata:
|
1533 |
+
if args.in_json is not None and (args.color_aug or args.random_crop):
|
1534 |
+
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
|
1535 |
+
|
1536 |
+
|
1537 |
+
def load_tokenizer(args: argparse.Namespace):
|
1538 |
+
print("prepare tokenizer")
|
1539 |
+
if args.v2:
|
1540 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1541 |
+
else:
|
1542 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1543 |
+
if args.max_token_length is not None:
|
1544 |
+
print(f"update token length: {args.max_token_length}")
|
1545 |
+
return tokenizer
|
1546 |
+
|
1547 |
+
|
1548 |
+
def prepare_accelerator(args: argparse.Namespace):
|
1549 |
+
if args.logging_dir is None:
|
1550 |
+
log_with = None
|
1551 |
+
logging_dir = None
|
1552 |
+
else:
|
1553 |
+
log_with = "tensorboard"
|
1554 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
1555 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
1556 |
+
|
1557 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
|
1558 |
+
log_with=log_with, logging_dir=logging_dir)
|
1559 |
+
|
1560 |
+
# accelerateの互換性問題を解決する
|
1561 |
+
accelerator_0_15 = True
|
1562 |
+
try:
|
1563 |
+
accelerator.unwrap_model("dummy", True)
|
1564 |
+
print("Using accelerator 0.15.0 or above.")
|
1565 |
+
except TypeError:
|
1566 |
+
accelerator_0_15 = False
|
1567 |
+
|
1568 |
+
def unwrap_model(model):
|
1569 |
+
if accelerator_0_15:
|
1570 |
+
return accelerator.unwrap_model(model, True)
|
1571 |
+
return accelerator.unwrap_model(model)
|
1572 |
+
|
1573 |
+
return accelerator, unwrap_model
|
1574 |
+
|
1575 |
+
|
1576 |
+
def prepare_dtype(args: argparse.Namespace):
|
1577 |
+
weight_dtype = torch.float32
|
1578 |
+
if args.mixed_precision == "fp16":
|
1579 |
+
weight_dtype = torch.float16
|
1580 |
+
elif args.mixed_precision == "bf16":
|
1581 |
+
weight_dtype = torch.bfloat16
|
1582 |
+
|
1583 |
+
save_dtype = None
|
1584 |
+
if args.save_precision == "fp16":
|
1585 |
+
save_dtype = torch.float16
|
1586 |
+
elif args.save_precision == "bf16":
|
1587 |
+
save_dtype = torch.bfloat16
|
1588 |
+
elif args.save_precision == "float":
|
1589 |
+
save_dtype = torch.float32
|
1590 |
+
|
1591 |
+
return weight_dtype, save_dtype
|
1592 |
+
|
1593 |
+
|
1594 |
+
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
1596 |
+
if load_stable_diffusion_format:
|
1597 |
+
print("load StableDiffusion checkpoint")
|
1598 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1599 |
+
else:
|
1600 |
+
print("load Diffusers pretrained models")
|
1601 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
1602 |
+
text_encoder = pipe.text_encoder
|
1603 |
+
vae = pipe.vae
|
1604 |
+
unet = pipe.unet
|
1605 |
+
del pipe
|
1606 |
+
|
1607 |
+
# VAEを読み込む
|
1608 |
+
if args.vae is not None:
|
1609 |
+
vae = model_util.load_vae(args.vae, weight_dtype)
|
1610 |
+
print("additional VAE loaded")
|
1611 |
+
|
1612 |
+
return text_encoder, vae, unet, load_stable_diffusion_format
|
1613 |
+
|
1614 |
+
|
1615 |
+
def patch_accelerator_for_fp16_training(accelerator):
|
1616 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
1617 |
+
|
1618 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
1619 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
1620 |
+
|
1621 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
1622 |
+
|
1623 |
+
|
1624 |
+
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
|
1625 |
+
# with no_token_padding, the length is not max length, return result immediately
|
1626 |
+
if input_ids.size()[-1] != tokenizer.model_max_length:
|
1627 |
+
return text_encoder(input_ids)[0]
|
1628 |
+
|
1629 |
+
b_size = input_ids.size()[0]
|
1630 |
+
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
1631 |
+
|
1632 |
+
if args.clip_skip is None:
|
1633 |
+
encoder_hidden_states = text_encoder(input_ids)[0]
|
1634 |
+
else:
|
1635 |
+
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
1636 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
1637 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
1638 |
+
|
1639 |
+
# bs*3, 77, 768 or 1024
|
1640 |
+
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
1641 |
+
|
1642 |
+
if args.max_token_length is not None:
|
1643 |
+
if args.v2:
|
1644 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
1645 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1646 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1647 |
+
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
1648 |
+
if i > 0:
|
1649 |
+
for j in range(len(chunk)):
|
1650 |
+
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
1651 |
+
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
1652 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
1653 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
1654 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1655 |
+
else:
|
1656 |
+
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
1657 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1658 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1659 |
+
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
1660 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
1661 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1662 |
+
|
1663 |
+
if weight_dtype is not None:
|
1664 |
+
# this is required for additional network training
|
1665 |
+
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
1666 |
+
|
1667 |
+
return encoder_hidden_states
|
1668 |
+
|
1669 |
+
|
1670 |
+
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
1671 |
+
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
1672 |
+
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
1673 |
+
return model_name, ckpt_name
|
1674 |
+
|
1675 |
+
|
1676 |
+
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
1677 |
+
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
1678 |
+
if saving:
|
1679 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1680 |
+
save_func()
|
1681 |
+
|
1682 |
+
if args.save_last_n_epochs is not None:
|
1683 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
1684 |
+
remove_old_func(remove_epoch_no)
|
1685 |
+
return saving
|
1686 |
+
|
1687 |
+
|
1688 |
+
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
1689 |
+
epoch_no = epoch + 1
|
1690 |
+
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
1691 |
+
|
1692 |
+
if save_stable_diffusion_format:
|
1693 |
+
def save_sd():
|
1694 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1695 |
+
print(f"saving checkpoint: {ckpt_file}")
|
1696 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1697 |
+
src_path, epoch_no, global_step, save_dtype, vae)
|
1698 |
+
|
1699 |
+
def remove_sd(old_epoch_no):
|
1700 |
+
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
1701 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1702 |
+
if os.path.exists(old_ckpt_file):
|
1703 |
+
print(f"removing old checkpoint: {old_ckpt_file}")
|
1704 |
+
os.remove(old_ckpt_file)
|
1705 |
+
|
1706 |
+
save_func = save_sd
|
1707 |
+
remove_old_func = remove_sd
|
1708 |
+
else:
|
1709 |
+
def save_du():
|
1710 |
+
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
1711 |
+
print(f"saving model: {out_dir}")
|
1712 |
+
os.makedirs(out_dir, exist_ok=True)
|
1713 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1714 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1715 |
+
|
1716 |
+
def remove_du(old_epoch_no):
|
1717 |
+
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
1718 |
+
if os.path.exists(out_dir_old):
|
1719 |
+
print(f"removing old model: {out_dir_old}")
|
1720 |
+
shutil.rmtree(out_dir_old)
|
1721 |
+
|
1722 |
+
save_func = save_du
|
1723 |
+
remove_old_func = remove_du
|
1724 |
+
|
1725 |
+
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
1726 |
+
if saving and args.save_state:
|
1727 |
+
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
1728 |
+
|
1729 |
+
|
1730 |
+
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
1731 |
+
print("saving state.")
|
1732 |
+
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
1733 |
+
|
1734 |
+
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
1735 |
+
if last_n_epochs is not None:
|
1736 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
1737 |
+
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
1738 |
+
if os.path.exists(state_dir_old):
|
1739 |
+
print(f"removing old state: {state_dir_old}")
|
1740 |
+
shutil.rmtree(state_dir_old)
|
1741 |
+
|
1742 |
+
|
1743 |
+
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
|
1744 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1745 |
+
|
1746 |
+
if save_stable_diffusion_format:
|
1747 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1748 |
+
|
1749 |
+
ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
|
1750 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1751 |
+
|
1752 |
+
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
1753 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1754 |
+
src_path, epoch, global_step, save_dtype, vae)
|
1755 |
+
else:
|
1756 |
+
out_dir = os.path.join(args.output_dir, model_name)
|
1757 |
+
os.makedirs(out_dir, exist_ok=True)
|
1758 |
+
|
1759 |
+
print(f"save trained model as Diffusers to {out_dir}")
|
1760 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1761 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1762 |
+
|
1763 |
+
|
1764 |
+
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
1765 |
+
print("saving last state.")
|
1766 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1767 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
+
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
+
|
1770 |
+
# endregion
|
1771 |
+
|
1772 |
+
# region 前処理用
|
1773 |
+
|
1774 |
+
|
1775 |
+
class ImageLoadingDataset(torch.utils.data.Dataset):
|
1776 |
+
def __init__(self, image_paths):
|
1777 |
+
self.images = image_paths
|
1778 |
+
|
1779 |
+
def __len__(self):
|
1780 |
+
return len(self.images)
|
1781 |
+
|
1782 |
+
def __getitem__(self, idx):
|
1783 |
+
img_path = self.images[idx]
|
1784 |
+
|
1785 |
+
try:
|
1786 |
+
image = Image.open(img_path).convert("RGB")
|
1787 |
+
# convert to tensor temporarily so dataloader will accept it
|
1788 |
+
tensor_pil = transforms.functional.pil_to_tensor(image)
|
1789 |
+
except Exception as e:
|
1790 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
1791 |
+
return None
|
1792 |
+
|
1793 |
+
return (tensor_pil, img_path)
|
1794 |
+
|
1795 |
+
|
1796 |
+
# endregion
|
fine_tune.py
CHANGED
@@ -13,11 +13,7 @@ import diffusers
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
-
|
17 |
-
from library.config_util import (
|
18 |
-
ConfigSanitizer,
|
19 |
-
BlueprintGenerator,
|
20 |
-
)
|
21 |
|
22 |
def collate_fn(examples):
|
23 |
return examples[0]
|
@@ -34,36 +30,25 @@ def train(args):
|
|
34 |
|
35 |
tokenizer = train_util.load_tokenizer(args)
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
"image_dir": args.train_data_dir,
|
49 |
-
"metadata_file": args.in_json,
|
50 |
-
}]
|
51 |
-
}]
|
52 |
-
}
|
53 |
-
|
54 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
if args.debug_dataset:
|
58 |
-
train_util.debug_dataset(
|
59 |
return
|
60 |
-
if len(
|
61 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
62 |
return
|
63 |
|
64 |
-
if cache_latents:
|
65 |
-
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
-
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -124,7 +109,7 @@ def train(args):
|
|
124 |
vae.requires_grad_(False)
|
125 |
vae.eval()
|
126 |
with torch.no_grad():
|
127 |
-
|
128 |
vae.to("cpu")
|
129 |
if torch.cuda.is_available():
|
130 |
torch.cuda.empty_cache()
|
@@ -164,13 +149,33 @@ def train(args):
|
|
164 |
|
165 |
# 学習に必要なクラスを準備する
|
166 |
print("prepare optimizer, data loader etc.")
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
# dataloaderを準備する
|
170 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
171 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
172 |
train_dataloader = torch.utils.data.DataLoader(
|
173 |
-
|
174 |
|
175 |
# 学習ステップ数を計算する
|
176 |
if args.max_train_epochs is not None:
|
@@ -178,9 +183,8 @@ def train(args):
|
|
178 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
179 |
|
180 |
# lr schedulerを用意する
|
181 |
-
lr_scheduler =
|
182 |
-
|
183 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
184 |
|
185 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
186 |
if args.full_fp16:
|
@@ -214,7 +218,7 @@ def train(args):
|
|
214 |
# 学習する
|
215 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
216 |
print("running training / 学習開始")
|
217 |
-
print(f" num examples / サンプル数: {
|
218 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
219 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
220 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -233,7 +237,7 @@ def train(args):
|
|
233 |
|
234 |
for epoch in range(num_train_epochs):
|
235 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
236 |
-
|
237 |
|
238 |
for m in training_models:
|
239 |
m.train()
|
@@ -282,11 +286,11 @@ def train(args):
|
|
282 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
283 |
|
284 |
accelerator.backward(loss)
|
285 |
-
if accelerator.sync_gradients
|
286 |
params_to_clip = []
|
287 |
for m in training_models:
|
288 |
params_to_clip.extend(m.parameters())
|
289 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
@@ -297,16 +301,11 @@ def train(args):
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
300 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
301 |
-
|
302 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
303 |
if args.logging_dir is not None:
|
304 |
-
logs = {"loss": current_loss, "lr":
|
305 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
306 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
309 |
-
# TODO moving averageにする
|
310 |
loss_total += current_loss
|
311 |
avr_loss = loss_total / (step+1)
|
312 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
@@ -316,7 +315,7 @@ def train(args):
|
|
316 |
break
|
317 |
|
318 |
if args.logging_dir is not None:
|
319 |
-
logs = {"
|
320 |
accelerator.log(logs, step=epoch+1)
|
321 |
|
322 |
accelerator.wait_for_everyone()
|
@@ -326,8 +325,6 @@ def train(args):
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
329 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
330 |
-
|
331 |
is_main_process = accelerator.is_main_process
|
332 |
if is_main_process:
|
333 |
unet = unwrap_model(unet)
|
@@ -354,8 +351,6 @@ if __name__ == '__main__':
|
|
354 |
train_util.add_dataset_arguments(parser, False, True, True)
|
355 |
train_util.add_training_arguments(parser, False)
|
356 |
train_util.add_sd_saving_arguments(parser)
|
357 |
-
train_util.add_optimizer_arguments(parser)
|
358 |
-
config_util.add_config_arguments(parser)
|
359 |
|
360 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
361 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
+
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def collate_fn(examples):
|
19 |
return examples[0]
|
|
|
30 |
|
31 |
tokenizer = train_util.load_tokenizer(args)
|
32 |
|
33 |
+
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
34 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
35 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
36 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
37 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
38 |
+
args.dataset_repeats, args.debug_dataset)
|
39 |
+
|
40 |
+
# 学習データのdropout率を設定する
|
41 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
42 |
+
|
43 |
+
train_dataset.make_buckets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
if args.debug_dataset:
|
46 |
+
train_util.debug_dataset(train_dataset)
|
47 |
return
|
48 |
+
if len(train_dataset) == 0:
|
49 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
50 |
return
|
51 |
|
|
|
|
|
|
|
52 |
# acceleratorを準備する
|
53 |
print("prepare accelerator")
|
54 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
109 |
vae.requires_grad_(False)
|
110 |
vae.eval()
|
111 |
with torch.no_grad():
|
112 |
+
train_dataset.cache_latents(vae)
|
113 |
vae.to("cpu")
|
114 |
if torch.cuda.is_available():
|
115 |
torch.cuda.empty_cache()
|
|
|
149 |
|
150 |
# 学習に必要なクラスを準備する
|
151 |
print("prepare optimizer, data loader etc.")
|
152 |
+
|
153 |
+
# 8-bit Adamを使う
|
154 |
+
if args.use_8bit_adam:
|
155 |
+
try:
|
156 |
+
import bitsandbytes as bnb
|
157 |
+
except ImportError:
|
158 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
159 |
+
print("use 8-bit Adam optimizer")
|
160 |
+
optimizer_class = bnb.optim.AdamW8bit
|
161 |
+
elif args.use_lion_optimizer:
|
162 |
+
try:
|
163 |
+
import lion_pytorch
|
164 |
+
except ImportError:
|
165 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
166 |
+
print("use Lion optimizer")
|
167 |
+
optimizer_class = lion_pytorch.Lion
|
168 |
+
else:
|
169 |
+
optimizer_class = torch.optim.AdamW
|
170 |
+
|
171 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
172 |
+
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
173 |
|
174 |
# dataloaderを準備する
|
175 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
176 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
177 |
train_dataloader = torch.utils.data.DataLoader(
|
178 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
179 |
|
180 |
# 学習ステップ数を計算する
|
181 |
if args.max_train_epochs is not None:
|
|
|
183 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
+
lr_scheduler = diffusers.optimization.get_scheduler(
|
187 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
|
|
188 |
|
189 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
190 |
if args.full_fp16:
|
|
|
218 |
# 学習する
|
219 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
220 |
print("running training / 学習開始")
|
221 |
+
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
|
222 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
223 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
224 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
237 |
|
238 |
for epoch in range(num_train_epochs):
|
239 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
240 |
+
train_dataset.set_current_epoch(epoch + 1)
|
241 |
|
242 |
for m in training_models:
|
243 |
m.train()
|
|
|
286 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
287 |
|
288 |
accelerator.backward(loss)
|
289 |
+
if accelerator.sync_gradients:
|
290 |
params_to_clip = []
|
291 |
for m in training_models:
|
292 |
params_to_clip.extend(m.parameters())
|
293 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
294 |
|
295 |
optimizer.step()
|
296 |
lr_scheduler.step()
|
|
|
301 |
progress_bar.update(1)
|
302 |
global_step += 1
|
303 |
|
|
|
|
|
304 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
305 |
if args.logging_dir is not None:
|
306 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
|
|
309 |
loss_total += current_loss
|
310 |
avr_loss = loss_total / (step+1)
|
311 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
315 |
break
|
316 |
|
317 |
if args.logging_dir is not None:
|
318 |
+
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
319 |
accelerator.log(logs, step=epoch+1)
|
320 |
|
321 |
accelerator.wait_for_everyone()
|
|
|
325 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
326 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
327 |
|
|
|
|
|
328 |
is_main_process = accelerator.is_main_process
|
329 |
if is_main_process:
|
330 |
unet = unwrap_model(unet)
|
|
|
351 |
train_util.add_dataset_arguments(parser, False, True, True)
|
352 |
train_util.add_training_arguments(parser, False)
|
353 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
354 |
|
355 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
356 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
-
from typing import
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
@@ -60,6 +60,7 @@ import math
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
|
|
63 |
|
64 |
import diffusers
|
65 |
import numpy as np
|
@@ -80,9 +81,6 @@ from PIL import Image
|
|
80 |
from PIL.PngImagePlugin import PngInfo
|
81 |
|
82 |
import library.model_util as model_util
|
83 |
-
import library.train_util as train_util
|
84 |
-
import tools.original_control_net as original_control_net
|
85 |
-
from tools.original_control_net import ControlNetInfo
|
86 |
|
87 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
88 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
@@ -489,9 +487,6 @@ class PipelineLike():
|
|
489 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
490 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
491 |
|
492 |
-
# ControlNet
|
493 |
-
self.control_nets: List[ControlNetInfo] = []
|
494 |
-
|
495 |
# Textual Inversion
|
496 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
497 |
self.token_replacements[target_token_id] = rep_token_ids
|
@@ -505,11 +500,7 @@ class PipelineLike():
|
|
505 |
new_tokens.append(token)
|
506 |
return new_tokens
|
507 |
|
508 |
-
def set_control_nets(self, ctrl_nets):
|
509 |
-
self.control_nets = ctrl_nets
|
510 |
-
|
511 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
512 |
-
|
513 |
def enable_xformers_memory_efficient_attention(self):
|
514 |
r"""
|
515 |
Enable memory efficient attention as implemented in xformers.
|
@@ -590,8 +581,6 @@ class PipelineLike():
|
|
590 |
latents: Optional[torch.FloatTensor] = None,
|
591 |
max_embeddings_multiples: Optional[int] = 3,
|
592 |
output_type: Optional[str] = "pil",
|
593 |
-
vae_batch_size: float = None,
|
594 |
-
return_latents: bool = False,
|
595 |
# return_dict: bool = True,
|
596 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
597 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
@@ -683,9 +672,6 @@ class PipelineLike():
|
|
683 |
else:
|
684 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
685 |
|
686 |
-
vae_batch_size = batch_size if vae_batch_size is None else (
|
687 |
-
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
|
688 |
-
|
689 |
if strength < 0 or strength > 1:
|
690 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
691 |
|
@@ -766,7 +752,7 @@ class PipelineLike():
|
|
766 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
767 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
768 |
|
769 |
-
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None
|
770 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
771 |
clip_guide_images = [clip_guide_images]
|
772 |
|
@@ -779,7 +765,7 @@ class PipelineLike():
|
|
779 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
780 |
if len(image_embeddings_clip) == 1:
|
781 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
782 |
-
|
783 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
784 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
785 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
@@ -788,10 +774,6 @@ class PipelineLike():
|
|
788 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
789 |
if len(image_embeddings_vgg16) == 1:
|
790 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
791 |
-
else:
|
792 |
-
# ControlNetのhintにguide imageを流用する
|
793 |
-
# 前処理はControlNet側で行う
|
794 |
-
pass
|
795 |
|
796 |
# set timesteps
|
797 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
@@ -799,6 +781,7 @@ class PipelineLike():
|
|
799 |
latents_dtype = text_embeddings.dtype
|
800 |
init_latents_orig = None
|
801 |
mask = None
|
|
|
802 |
|
803 |
if init_image is None:
|
804 |
# get the initial random noise unless the user supplied it
|
@@ -830,8 +813,6 @@ class PipelineLike():
|
|
830 |
if isinstance(init_image[0], PIL.Image.Image):
|
831 |
init_image = [preprocess_image(im) for im in init_image]
|
832 |
init_image = torch.cat(init_image)
|
833 |
-
if isinstance(init_image, list):
|
834 |
-
init_image = torch.stack(init_image)
|
835 |
|
836 |
# mask image to tensor
|
837 |
if mask_image is not None:
|
@@ -842,24 +823,9 @@ class PipelineLike():
|
|
842 |
|
843 |
# encode the init image into latents and scale the latents
|
844 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
if vae_batch_size >= batch_size:
|
849 |
-
init_latent_dist = self.vae.encode(init_image).latent_dist
|
850 |
-
init_latents = init_latent_dist.sample(generator=generator)
|
851 |
-
else:
|
852 |
-
if torch.cuda.is_available():
|
853 |
-
torch.cuda.empty_cache()
|
854 |
-
init_latents = []
|
855 |
-
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
856 |
-
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
|
857 |
-
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
|
858 |
-
init_latents.append(init_latent_dist.sample(generator=generator))
|
859 |
-
init_latents = torch.cat(init_latents)
|
860 |
-
|
861 |
-
init_latents = 0.18215 * init_latents
|
862 |
-
|
863 |
if len(init_latents) == 1:
|
864 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
865 |
init_latents_orig = init_latents
|
@@ -898,21 +864,12 @@ class PipelineLike():
|
|
898 |
extra_step_kwargs["eta"] = eta
|
899 |
|
900 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
901 |
-
|
902 |
-
if self.control_nets:
|
903 |
-
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
904 |
-
|
905 |
for i, t in enumerate(tqdm(timesteps)):
|
906 |
# expand the latents if we are doing classifier free guidance
|
907 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
908 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
909 |
-
|
910 |
# predict the noise residual
|
911 |
-
|
912 |
-
noise_pred = original_control_net.call_unet_and_control_net(
|
913 |
-
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
|
914 |
-
else:
|
915 |
-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
916 |
|
917 |
# perform guidance
|
918 |
if do_classifier_free_guidance:
|
@@ -954,19 +911,8 @@ class PipelineLike():
|
|
954 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
955 |
return None
|
956 |
|
957 |
-
if return_latents:
|
958 |
-
return (latents, False)
|
959 |
-
|
960 |
latents = 1 / 0.18215 * latents
|
961 |
-
|
962 |
-
image = self.vae.decode(latents).sample
|
963 |
-
else:
|
964 |
-
if torch.cuda.is_available():
|
965 |
-
torch.cuda.empty_cache()
|
966 |
-
images = []
|
967 |
-
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
968 |
-
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
|
969 |
-
image = torch.cat(images)
|
970 |
|
971 |
image = (image / 2 + 0.5).clamp(0, 1)
|
972 |
|
@@ -1649,11 +1595,10 @@ def get_unweighted_text_embeddings(
|
|
1649 |
if pad == eos: # v1
|
1650 |
text_input_chunk[:, -1] = text_input[0, -1]
|
1651 |
else: # v2
|
1652 |
-
|
1653 |
-
|
1654 |
-
|
1655 |
-
|
1656 |
-
text_input_chunk[j, 1] = eos
|
1657 |
|
1658 |
if clip_skip is None or clip_skip == 1:
|
1659 |
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
@@ -1854,7 +1799,7 @@ def preprocess_mask(mask):
|
|
1854 |
mask = mask.convert("L")
|
1855 |
w, h = mask.size
|
1856 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1857 |
-
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.
|
1858 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1859 |
mask = np.tile(mask, (4, 1, 1))
|
1860 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
@@ -1872,35 +1817,6 @@ def preprocess_mask(mask):
|
|
1872 |
# return text_encoder
|
1873 |
|
1874 |
|
1875 |
-
class BatchDataBase(NamedTuple):
|
1876 |
-
# バッチ分割が必要ないデータ
|
1877 |
-
step: int
|
1878 |
-
prompt: str
|
1879 |
-
negative_prompt: str
|
1880 |
-
seed: int
|
1881 |
-
init_image: Any
|
1882 |
-
mask_image: Any
|
1883 |
-
clip_prompt: str
|
1884 |
-
guide_image: Any
|
1885 |
-
|
1886 |
-
|
1887 |
-
class BatchDataExt(NamedTuple):
|
1888 |
-
# バッチ分割が必要なデータ
|
1889 |
-
width: int
|
1890 |
-
height: int
|
1891 |
-
steps: int
|
1892 |
-
scale: float
|
1893 |
-
negative_scale: float
|
1894 |
-
strength: float
|
1895 |
-
network_muls: Tuple[float]
|
1896 |
-
|
1897 |
-
|
1898 |
-
class BatchData(NamedTuple):
|
1899 |
-
return_latents: bool
|
1900 |
-
base: BatchDataBase
|
1901 |
-
ext: BatchDataExt
|
1902 |
-
|
1903 |
-
|
1904 |
def main(args):
|
1905 |
if args.fp16:
|
1906 |
dtype = torch.float16
|
@@ -1965,7 +1881,10 @@ def main(args):
|
|
1965 |
# tokenizerを読み込む
|
1966 |
print("loading tokenizer")
|
1967 |
if use_stable_diffusion_format:
|
1968 |
-
|
|
|
|
|
|
|
1969 |
|
1970 |
# schedulerを用意する
|
1971 |
sched_init_args = {}
|
@@ -2076,13 +1995,11 @@ def main(args):
|
|
2076 |
# networkを組み込む
|
2077 |
if args.network_module:
|
2078 |
networks = []
|
2079 |
-
network_default_muls = []
|
2080 |
for i, network_module in enumerate(args.network_module):
|
2081 |
print("import network module:", network_module)
|
2082 |
imported_module = importlib.import_module(network_module)
|
2083 |
|
2084 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
2085 |
-
network_default_muls.append(network_mul)
|
2086 |
|
2087 |
net_kwargs = {}
|
2088 |
if args.network_args and i < len(args.network_args):
|
@@ -2097,7 +2014,7 @@ def main(args):
|
|
2097 |
network_weight = args.network_weights[i]
|
2098 |
print("load network weights from:", network_weight)
|
2099 |
|
2100 |
-
if model_util.is_safetensors(network_weight)
|
2101 |
from safetensors.torch import safe_open
|
2102 |
with safe_open(network_weight, framework="pt") as f:
|
2103 |
metadata = f.metadata()
|
@@ -2120,18 +2037,6 @@ def main(args):
|
|
2120 |
else:
|
2121 |
networks = []
|
2122 |
|
2123 |
-
# ControlNetの処理
|
2124 |
-
control_nets: List[ControlNetInfo] = []
|
2125 |
-
if args.control_net_models:
|
2126 |
-
for i, model in enumerate(args.control_net_models):
|
2127 |
-
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
2128 |
-
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
2129 |
-
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
2130 |
-
|
2131 |
-
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
2132 |
-
prep = original_control_net.load_preprocess(prep_type)
|
2133 |
-
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
2134 |
-
|
2135 |
if args.opt_channels_last:
|
2136 |
print(f"set optimizing: channels last")
|
2137 |
text_encoder.to(memory_format=torch.channels_last)
|
@@ -2145,14 +2050,9 @@ def main(args):
|
|
2145 |
if vgg16_model is not None:
|
2146 |
vgg16_model.to(memory_format=torch.channels_last)
|
2147 |
|
2148 |
-
for cn in control_nets:
|
2149 |
-
cn.unet.to(memory_format=torch.channels_last)
|
2150 |
-
cn.net.to(memory_format=torch.channels_last)
|
2151 |
-
|
2152 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2153 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2154 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
2155 |
-
pipe.set_control_nets(control_nets)
|
2156 |
print("pipeline is ready.")
|
2157 |
|
2158 |
if args.diffusers_xformers:
|
@@ -2277,34 +2177,18 @@ def main(args):
|
|
2277 |
mask_images = l
|
2278 |
|
2279 |
# 画像サイズにオプション指定があるときはリサイズする
|
2280 |
-
if args.W is not None and args.H is not None:
|
2281 |
-
|
2282 |
-
|
2283 |
-
init_images = resize_images(init_images, (args.W, args.H))
|
2284 |
if mask_images is not None:
|
2285 |
print(f"resize img2img mask images to {args.W}*{args.H}")
|
2286 |
mask_images = resize_images(mask_images, (args.W, args.H))
|
2287 |
|
2288 |
-
if networks and mask_images:
|
2289 |
-
# mask を領域情報として流用する、現在は1枚だけ対応
|
2290 |
-
# TODO 複数のnetwork classの混在時の考慮
|
2291 |
-
print("use mask as region")
|
2292 |
-
# import cv2
|
2293 |
-
# for i in range(3):
|
2294 |
-
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
2295 |
-
# cv2.waitKey()
|
2296 |
-
# cv2.destroyAllWindows()
|
2297 |
-
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
2298 |
-
mask_images = None
|
2299 |
-
|
2300 |
prev_image = None # for VGG16 guided
|
2301 |
if args.guide_image_path is not None:
|
2302 |
-
print(f"load image for CLIP/VGG16
|
2303 |
-
guide_images =
|
2304 |
-
for
|
2305 |
-
guide_images.extend(load_images(p))
|
2306 |
-
|
2307 |
-
print(f"loaded {len(guide_images)} guide images for guidance")
|
2308 |
if len(guide_images) == 0:
|
2309 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2310 |
guide_images = None
|
@@ -2335,46 +2219,33 @@ def main(args):
|
|
2335 |
iter_seed = random.randint(0, 0x7fffffff)
|
2336 |
|
2337 |
# バッチ処理の関数
|
2338 |
-
def process_batch(batch
|
2339 |
batch_size = len(batch)
|
2340 |
|
2341 |
# highres_fixの処理
|
2342 |
if highres_fix and not highres_1st:
|
2343 |
-
# 1st stage
|
2344 |
-
print("process 1st
|
2345 |
batch_1st = []
|
2346 |
-
for
|
2347 |
-
width_1st = int(
|
2348 |
-
height_1st = int(
|
2349 |
width_1st = width_1st - width_1st % 32
|
2350 |
height_1st = height_1st - height_1st % 32
|
2351 |
-
|
2352 |
-
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
2353 |
-
ext.negative_scale, ext.strength, ext.network_muls)
|
2354 |
-
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
2355 |
images_1st = process_batch(batch_1st, True, True)
|
2356 |
|
2357 |
# 2nd stageのバッチを作成して以下処理する
|
2358 |
-
print("process 2nd
|
2359 |
-
if args.highres_fix_latents_upscaling:
|
2360 |
-
org_dtype = images_1st.dtype
|
2361 |
-
if images_1st.dtype == torch.bfloat16:
|
2362 |
-
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
2363 |
-
images_1st = torch.nn.functional.interpolate(
|
2364 |
-
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
|
2365 |
-
images_1st = images_1st.to(org_dtype)
|
2366 |
-
|
2367 |
batch_2nd = []
|
2368 |
-
for i, (
|
2369 |
-
|
2370 |
-
|
2371 |
-
|
2372 |
-
batch_2nd.append(bd_2nd)
|
2373 |
batch = batch_2nd
|
2374 |
|
2375 |
-
|
2376 |
-
|
2377 |
-
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
2378 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2379 |
|
2380 |
prompts = []
|
@@ -2407,7 +2278,7 @@ def main(args):
|
|
2407 |
all_images_are_same = True
|
2408 |
all_masks_are_same = True
|
2409 |
all_guide_images_are_same = True
|
2410 |
-
for i, (
|
2411 |
prompts.append(prompt)
|
2412 |
negative_prompts.append(negative_prompt)
|
2413 |
seeds.append(seed)
|
@@ -2424,13 +2295,9 @@ def main(args):
|
|
2424 |
all_masks_are_same = mask_images[-2] is mask_image
|
2425 |
|
2426 |
if guide_image is not None:
|
2427 |
-
|
2428 |
-
|
2429 |
-
all_guide_images_are_same =
|
2430 |
-
else:
|
2431 |
-
guide_images.append(guide_image)
|
2432 |
-
if i > 0 and all_guide_images_are_same:
|
2433 |
-
all_guide_images_are_same = guide_images[-2] is guide_image
|
2434 |
|
2435 |
# make start code
|
2436 |
torch.manual_seed(seed)
|
@@ -2453,24 +2320,10 @@ def main(args):
|
|
2453 |
if guide_images is not None and all_guide_images_are_same:
|
2454 |
guide_images = guide_images[0]
|
2455 |
|
2456 |
-
# ControlNet使用時はguide imageをリサイズする
|
2457 |
-
if control_nets:
|
2458 |
-
# TODO resample��メソッド
|
2459 |
-
guide_images = guide_images if type(guide_images) == list else [guide_images]
|
2460 |
-
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
|
2461 |
-
if len(guide_images) == 1:
|
2462 |
-
guide_images = guide_images[0]
|
2463 |
-
|
2464 |
# generate
|
2465 |
-
if networks:
|
2466 |
-
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
2467 |
-
n.set_multiplier(m)
|
2468 |
-
|
2469 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2470 |
-
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
2471 |
-
|
2472 |
-
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
2473 |
-
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
2474 |
return images
|
2475 |
|
2476 |
# save image
|
@@ -2545,7 +2398,6 @@ def main(args):
|
|
2545 |
strength = 0.8 if args.strength is None else args.strength
|
2546 |
negative_prompt = ""
|
2547 |
clip_prompt = None
|
2548 |
-
network_muls = None
|
2549 |
|
2550 |
prompt_args = prompt.strip().split(' --')
|
2551 |
prompt = prompt_args[0]
|
@@ -2609,15 +2461,6 @@ def main(args):
|
|
2609 |
clip_prompt = m.group(1)
|
2610 |
print(f"clip prompt: {clip_prompt}")
|
2611 |
continue
|
2612 |
-
|
2613 |
-
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
|
2614 |
-
if m: # network multiplies
|
2615 |
-
network_muls = [float(v) for v in m.group(1).split(",")]
|
2616 |
-
while len(network_muls) < len(networks):
|
2617 |
-
network_muls.append(network_muls[-1])
|
2618 |
-
print(f"network mul: {network_muls}")
|
2619 |
-
continue
|
2620 |
-
|
2621 |
except ValueError as ex:
|
2622 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2623 |
print(ex)
|
@@ -2655,12 +2498,7 @@ def main(args):
|
|
2655 |
mask_image = mask_images[global_step % len(mask_images)]
|
2656 |
|
2657 |
if guide_images is not None:
|
2658 |
-
|
2659 |
-
c = len(control_nets)
|
2660 |
-
p = global_step % (len(guide_images) // c)
|
2661 |
-
guide_image = guide_images[p * c:p * c + c]
|
2662 |
-
else:
|
2663 |
-
guide_image = guide_images[global_step % len(guide_images)]
|
2664 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2665 |
if prev_image is None:
|
2666 |
print("Generate 1st image without guide image.")
|
@@ -2668,9 +2506,10 @@ def main(args):
|
|
2668 |
print("Use previous image as guide image.")
|
2669 |
guide_image = prev_image
|
2670 |
|
2671 |
-
|
2672 |
-
|
2673 |
-
|
|
|
2674 |
process_batch(batch_data, highres_fix)
|
2675 |
batch_data.clear()
|
2676 |
|
@@ -2714,8 +2553,6 @@ if __name__ == '__main__':
|
|
2714 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2715 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2716 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
2717 |
-
parser.add_argument("--vae_batch_size", type=float, default=None,
|
2718 |
-
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
|
2719 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2720 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2721 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
@@ -2727,8 +2564,6 @@ if __name__ == '__main__':
|
|
2727 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2728 |
parser.add_argument("--vae", type=str, default=None,
|
2729 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
2730 |
-
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
2731 |
-
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
2732 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2733 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2734 |
parser.add_argument("--seed", type=int, default=None,
|
@@ -2743,15 +2578,12 @@ if __name__ == '__main__':
|
|
2743 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2744 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2745 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2746 |
-
help='
|
2747 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2748 |
-
help='
|
2749 |
-
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
2750 |
-
help='additional network multiplier / 追加ネットワークの効果の倍率')
|
2751 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2752 |
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
2753 |
-
parser.add_argument("--network_show_meta", action='store_true',
|
2754 |
-
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
|
2755 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2756 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2757 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
@@ -2765,26 +2597,15 @@ if __name__ == '__main__':
|
|
2765 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2766 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2767 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2768 |
-
parser.add_argument("--guide_image_path", type=str, default=None,
|
2769 |
-
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
2770 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2771 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2772 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2773 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2774 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2775 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
2776 |
-
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
|
2777 |
-
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
|
2778 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2779 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2780 |
|
2781 |
-
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
|
2782 |
-
help='ControlNet models to use / 使用するControlNetのモデル名')
|
2783 |
-
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
|
2784 |
-
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
|
2785 |
-
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
|
2786 |
-
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
|
2787 |
-
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
|
2788 |
-
|
2789 |
args = parser.parse_args()
|
2790 |
main(args)
|
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
+
from typing import List, Optional, Union
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
63 |
+
from typing import Any, Callable, List, Optional, Union
|
64 |
|
65 |
import diffusers
|
66 |
import numpy as np
|
|
|
81 |
from PIL.PngImagePlugin import PngInfo
|
82 |
|
83 |
import library.model_util as model_util
|
|
|
|
|
|
|
84 |
|
85 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
86 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
|
487 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
488 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
489 |
|
|
|
|
|
|
|
490 |
# Textual Inversion
|
491 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
492 |
self.token_replacements[target_token_id] = rep_token_ids
|
|
|
500 |
new_tokens.append(token)
|
501 |
return new_tokens
|
502 |
|
|
|
|
|
|
|
503 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
|
|
504 |
def enable_xformers_memory_efficient_attention(self):
|
505 |
r"""
|
506 |
Enable memory efficient attention as implemented in xformers.
|
|
|
581 |
latents: Optional[torch.FloatTensor] = None,
|
582 |
max_embeddings_multiples: Optional[int] = 3,
|
583 |
output_type: Optional[str] = "pil",
|
|
|
|
|
584 |
# return_dict: bool = True,
|
585 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
586 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
|
|
672 |
else:
|
673 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
674 |
|
|
|
|
|
|
|
675 |
if strength < 0 or strength > 1:
|
676 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
677 |
|
|
|
752 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
753 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
754 |
|
755 |
+
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
|
756 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
757 |
clip_guide_images = [clip_guide_images]
|
758 |
|
|
|
765 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
766 |
if len(image_embeddings_clip) == 1:
|
767 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
768 |
+
else:
|
769 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
770 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
771 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
|
|
774 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
775 |
if len(image_embeddings_vgg16) == 1:
|
776 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
|
|
|
|
|
|
|
|
777 |
|
778 |
# set timesteps
|
779 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
|
|
781 |
latents_dtype = text_embeddings.dtype
|
782 |
init_latents_orig = None
|
783 |
mask = None
|
784 |
+
noise = None
|
785 |
|
786 |
if init_image is None:
|
787 |
# get the initial random noise unless the user supplied it
|
|
|
813 |
if isinstance(init_image[0], PIL.Image.Image):
|
814 |
init_image = [preprocess_image(im) for im in init_image]
|
815 |
init_image = torch.cat(init_image)
|
|
|
|
|
816 |
|
817 |
# mask image to tensor
|
818 |
if mask_image is not None:
|
|
|
823 |
|
824 |
# encode the init image into latents and scale the latents
|
825 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
826 |
+
init_latent_dist = self.vae.encode(init_image).latent_dist
|
827 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
828 |
+
init_latents = 0.18215 * init_latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
if len(init_latents) == 1:
|
830 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
831 |
init_latents_orig = init_latents
|
|
|
864 |
extra_step_kwargs["eta"] = eta
|
865 |
|
866 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
|
|
|
|
|
|
|
|
867 |
for i, t in enumerate(tqdm(timesteps)):
|
868 |
# expand the latents if we are doing classifier free guidance
|
869 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
870 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
871 |
# predict the noise residual
|
872 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
|
|
|
|
|
|
|
|
873 |
|
874 |
# perform guidance
|
875 |
if do_classifier_free_guidance:
|
|
|
911 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
912 |
return None
|
913 |
|
|
|
|
|
|
|
914 |
latents = 1 / 0.18215 * latents
|
915 |
+
image = self.vae.decode(latents).sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
916 |
|
917 |
image = (image / 2 + 0.5).clamp(0, 1)
|
918 |
|
|
|
1595 |
if pad == eos: # v1
|
1596 |
text_input_chunk[:, -1] = text_input[0, -1]
|
1597 |
else: # v2
|
1598 |
+
if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
|
1599 |
+
text_input_chunk[:, -1] = eos
|
1600 |
+
if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
|
1601 |
+
text_input_chunk[:, 1] = eos
|
|
|
1602 |
|
1603 |
if clip_skip is None or clip_skip == 1:
|
1604 |
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
|
|
1799 |
mask = mask.convert("L")
|
1800 |
w, h = mask.size
|
1801 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1802 |
+
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
|
1803 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1804 |
mask = np.tile(mask, (4, 1, 1))
|
1805 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
|
|
1817 |
# return text_encoder
|
1818 |
|
1819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1820 |
def main(args):
|
1821 |
if args.fp16:
|
1822 |
dtype = torch.float16
|
|
|
1881 |
# tokenizerを読み込む
|
1882 |
print("loading tokenizer")
|
1883 |
if use_stable_diffusion_format:
|
1884 |
+
if args.v2:
|
1885 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1886 |
+
else:
|
1887 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
1888 |
|
1889 |
# schedulerを用意する
|
1890 |
sched_init_args = {}
|
|
|
1995 |
# networkを組み込む
|
1996 |
if args.network_module:
|
1997 |
networks = []
|
|
|
1998 |
for i, network_module in enumerate(args.network_module):
|
1999 |
print("import network module:", network_module)
|
2000 |
imported_module = importlib.import_module(network_module)
|
2001 |
|
2002 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
|
|
2003 |
|
2004 |
net_kwargs = {}
|
2005 |
if args.network_args and i < len(args.network_args):
|
|
|
2014 |
network_weight = args.network_weights[i]
|
2015 |
print("load network weights from:", network_weight)
|
2016 |
|
2017 |
+
if model_util.is_safetensors(network_weight):
|
2018 |
from safetensors.torch import safe_open
|
2019 |
with safe_open(network_weight, framework="pt") as f:
|
2020 |
metadata = f.metadata()
|
|
|
2037 |
else:
|
2038 |
networks = []
|
2039 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2040 |
if args.opt_channels_last:
|
2041 |
print(f"set optimizing: channels last")
|
2042 |
text_encoder.to(memory_format=torch.channels_last)
|
|
|
2050 |
if vgg16_model is not None:
|
2051 |
vgg16_model.to(memory_format=torch.channels_last)
|
2052 |
|
|
|
|
|
|
|
|
|
2053 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2054 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2055 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
|
|
2056 |
print("pipeline is ready.")
|
2057 |
|
2058 |
if args.diffusers_xformers:
|
|
|
2177 |
mask_images = l
|
2178 |
|
2179 |
# 画像サイズにオプション指定があるときはリサイズする
|
2180 |
+
if init_images is not None and args.W is not None and args.H is not None:
|
2181 |
+
print(f"resize img2img source images to {args.W}*{args.H}")
|
2182 |
+
init_images = resize_images(init_images, (args.W, args.H))
|
|
|
2183 |
if mask_images is not None:
|
2184 |
print(f"resize img2img mask images to {args.W}*{args.H}")
|
2185 |
mask_images = resize_images(mask_images, (args.W, args.H))
|
2186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2187 |
prev_image = None # for VGG16 guided
|
2188 |
if args.guide_image_path is not None:
|
2189 |
+
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
|
2190 |
+
guide_images = load_images(args.guide_image_path)
|
2191 |
+
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
|
|
|
|
|
|
|
2192 |
if len(guide_images) == 0:
|
2193 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2194 |
guide_images = None
|
|
|
2219 |
iter_seed = random.randint(0, 0x7fffffff)
|
2220 |
|
2221 |
# バッチ処理の関数
|
2222 |
+
def process_batch(batch, highres_fix, highres_1st=False):
|
2223 |
batch_size = len(batch)
|
2224 |
|
2225 |
# highres_fixの処理
|
2226 |
if highres_fix and not highres_1st:
|
2227 |
+
# 1st stageのバッチを作成して呼び出す
|
2228 |
+
print("process 1st stage1")
|
2229 |
batch_1st = []
|
2230 |
+
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
2231 |
+
width_1st = int(width * args.highres_fix_scale + .5)
|
2232 |
+
height_1st = int(height * args.highres_fix_scale + .5)
|
2233 |
width_1st = width_1st - width_1st % 32
|
2234 |
height_1st = height_1st - height_1st % 32
|
2235 |
+
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
|
|
|
|
|
|
2236 |
images_1st = process_batch(batch_1st, True, True)
|
2237 |
|
2238 |
# 2nd stageのバッチを作成して以下処理する
|
2239 |
+
print("process 2nd stage1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2240 |
batch_2nd = []
|
2241 |
+
for i, (b1, image) in enumerate(zip(batch, images_1st)):
|
2242 |
+
image = image.resize((width, height), resample=PIL.Image.LANCZOS)
|
2243 |
+
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
|
2244 |
+
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
|
|
2245 |
batch = batch_2nd
|
2246 |
|
2247 |
+
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
2248 |
+
height, steps, scale, negative_scale, strength) = batch[0]
|
|
|
2249 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2250 |
|
2251 |
prompts = []
|
|
|
2278 |
all_images_are_same = True
|
2279 |
all_masks_are_same = True
|
2280 |
all_guide_images_are_same = True
|
2281 |
+
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
2282 |
prompts.append(prompt)
|
2283 |
negative_prompts.append(negative_prompt)
|
2284 |
seeds.append(seed)
|
|
|
2295 |
all_masks_are_same = mask_images[-2] is mask_image
|
2296 |
|
2297 |
if guide_image is not None:
|
2298 |
+
guide_images.append(guide_image)
|
2299 |
+
if i > 0 and all_guide_images_are_same:
|
2300 |
+
all_guide_images_are_same = guide_images[-2] is guide_image
|
|
|
|
|
|
|
|
|
2301 |
|
2302 |
# make start code
|
2303 |
torch.manual_seed(seed)
|
|
|
2320 |
if guide_images is not None and all_guide_images_are_same:
|
2321 |
guide_images = guide_images[0]
|
2322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2323 |
# generate
|
|
|
|
|
|
|
|
|
2324 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2325 |
+
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
2326 |
+
if highres_1st and not args.highres_fix_save_1st:
|
|
|
|
|
2327 |
return images
|
2328 |
|
2329 |
# save image
|
|
|
2398 |
strength = 0.8 if args.strength is None else args.strength
|
2399 |
negative_prompt = ""
|
2400 |
clip_prompt = None
|
|
|
2401 |
|
2402 |
prompt_args = prompt.strip().split(' --')
|
2403 |
prompt = prompt_args[0]
|
|
|
2461 |
clip_prompt = m.group(1)
|
2462 |
print(f"clip prompt: {clip_prompt}")
|
2463 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2464 |
except ValueError as ex:
|
2465 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2466 |
print(ex)
|
|
|
2498 |
mask_image = mask_images[global_step % len(mask_images)]
|
2499 |
|
2500 |
if guide_images is not None:
|
2501 |
+
guide_image = guide_images[global_step % len(guide_images)]
|
|
|
|
|
|
|
|
|
|
|
2502 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2503 |
if prev_image is None:
|
2504 |
print("Generate 1st image without guide image.")
|
|
|
2506 |
print("Use previous image as guide image.")
|
2507 |
guide_image = prev_image
|
2508 |
|
2509 |
+
# TODO named tupleか何かにする
|
2510 |
+
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
2511 |
+
(width, height, steps, scale, negative_scale, strength))
|
2512 |
+
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
2513 |
process_batch(batch_data, highres_fix)
|
2514 |
batch_data.clear()
|
2515 |
|
|
|
2553 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2554 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2555 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
|
|
|
|
2556 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2557 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2558 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
|
|
2564 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2565 |
parser.add_argument("--vae", type=str, default=None,
|
2566 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
|
|
|
|
2567 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2568 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2569 |
parser.add_argument("--seed", type=int, default=None,
|
|
|
2578 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2579 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2580 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2581 |
+
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
2582 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2583 |
+
help='Hypernetwork weights to load / Hypernetworkの重み')
|
2584 |
+
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
|
|
2585 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2586 |
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
|
|
|
|
2587 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2588 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2589 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
|
|
2597 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2598 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2599 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2600 |
+
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
|
|
2601 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2602 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2603 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2604 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2605 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2606 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
|
|
|
|
2607 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2608 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2610 |
args = parser.parse_args()
|
2611 |
main(args)
|
library.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: library
|
3 |
+
Version: 0.0.0
|
4 |
+
License-File: LICENSE.md
|
library.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LICENSE.md
|
2 |
+
README.md
|
3 |
+
setup.py
|
4 |
+
library/__init__.py
|
5 |
+
library/model_util.py
|
6 |
+
library/train_util.py
|
7 |
+
library.egg-info/PKG-INFO
|
8 |
+
library.egg-info/SOURCES.txt
|
9 |
+
library.egg-info/dependency_links.txt
|
10 |
+
library.egg-info/top_level.txt
|
library.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
library.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
library
|
library/model_util.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
import math
|
5 |
import os
|
6 |
import torch
|
7 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
from safetensors.torch import load_file, save_file
|
10 |
|
@@ -916,11 +916,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
|
916 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
else:
|
918 |
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
-
|
920 |
-
logging.set_verbosity_error() # don't show annoying warning
|
921 |
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
922 |
-
logging.set_verbosity_warning()
|
923 |
-
|
924 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
925 |
print("loading text encoder:", info)
|
926 |
|
|
|
4 |
import math
|
5 |
import os
|
6 |
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
from safetensors.torch import load_file, save_file
|
10 |
|
|
|
916 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
else:
|
918 |
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
|
|
|
|
919 |
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
920 |
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
print("loading text encoder:", info)
|
922 |
|
library/train_util.py
CHANGED
@@ -1,21 +1,12 @@
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
4 |
-
import importlib
|
5 |
import json
|
6 |
-
import re
|
7 |
import shutil
|
8 |
import time
|
9 |
-
from typing import
|
10 |
-
Dict,
|
11 |
-
List,
|
12 |
-
NamedTuple,
|
13 |
-
Optional,
|
14 |
-
Sequence,
|
15 |
-
Tuple,
|
16 |
-
Union,
|
17 |
-
)
|
18 |
from accelerate import Accelerator
|
|
|
19 |
import glob
|
20 |
import math
|
21 |
import os
|
@@ -26,16 +17,10 @@ from io import BytesIO
|
|
26 |
|
27 |
from tqdm import tqdm
|
28 |
import torch
|
29 |
-
from torch.optim import Optimizer
|
30 |
from torchvision import transforms
|
31 |
from transformers import CLIPTokenizer
|
32 |
-
import transformers
|
33 |
import diffusers
|
34 |
-
from diffusers
|
35 |
-
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
36 |
-
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
37 |
-
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
38 |
-
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
39 |
import albumentations as albu
|
40 |
import numpy as np
|
41 |
from PIL import Image
|
@@ -210,95 +195,23 @@ class BucketBatchIndex(NamedTuple):
|
|
210 |
batch_index: int
|
211 |
|
212 |
|
213 |
-
class AugHelper:
|
214 |
-
def __init__(self):
|
215 |
-
# prepare all possible augmentators
|
216 |
-
color_aug_method = albu.OneOf([
|
217 |
-
albu.HueSaturationValue(8, 0, 0, p=.5),
|
218 |
-
albu.RandomGamma((95, 105), p=.5),
|
219 |
-
], p=.33)
|
220 |
-
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
221 |
-
|
222 |
-
# key: (use_color_aug, use_flip_aug)
|
223 |
-
self.augmentors = {
|
224 |
-
(True, True): albu.Compose([
|
225 |
-
color_aug_method,
|
226 |
-
flip_aug_method,
|
227 |
-
], p=1.),
|
228 |
-
(True, False): albu.Compose([
|
229 |
-
color_aug_method,
|
230 |
-
], p=1.),
|
231 |
-
(False, True): albu.Compose([
|
232 |
-
flip_aug_method,
|
233 |
-
], p=1.),
|
234 |
-
(False, False): None
|
235 |
-
}
|
236 |
-
|
237 |
-
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
238 |
-
return self.augmentors[(use_color_aug, use_flip_aug)]
|
239 |
-
|
240 |
-
|
241 |
-
class BaseSubset:
|
242 |
-
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
243 |
-
self.image_dir = image_dir
|
244 |
-
self.num_repeats = num_repeats
|
245 |
-
self.shuffle_caption = shuffle_caption
|
246 |
-
self.keep_tokens = keep_tokens
|
247 |
-
self.color_aug = color_aug
|
248 |
-
self.flip_aug = flip_aug
|
249 |
-
self.face_crop_aug_range = face_crop_aug_range
|
250 |
-
self.random_crop = random_crop
|
251 |
-
self.caption_dropout_rate = caption_dropout_rate
|
252 |
-
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
253 |
-
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
254 |
-
|
255 |
-
self.img_count = 0
|
256 |
-
|
257 |
-
|
258 |
-
class DreamBoothSubset(BaseSubset):
|
259 |
-
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
260 |
-
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
261 |
-
|
262 |
-
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
263 |
-
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
264 |
-
|
265 |
-
self.is_reg = is_reg
|
266 |
-
self.class_tokens = class_tokens
|
267 |
-
self.caption_extension = caption_extension
|
268 |
-
|
269 |
-
def __eq__(self, other) -> bool:
|
270 |
-
if not isinstance(other, DreamBoothSubset):
|
271 |
-
return NotImplemented
|
272 |
-
return self.image_dir == other.image_dir
|
273 |
-
|
274 |
-
|
275 |
-
class FineTuningSubset(BaseSubset):
|
276 |
-
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
277 |
-
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
278 |
-
|
279 |
-
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
280 |
-
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
281 |
-
|
282 |
-
self.metadata_file = metadata_file
|
283 |
-
|
284 |
-
def __eq__(self, other) -> bool:
|
285 |
-
if not isinstance(other, FineTuningSubset):
|
286 |
-
return NotImplemented
|
287 |
-
return self.metadata_file == other.metadata_file
|
288 |
-
|
289 |
-
|
290 |
class BaseDataset(torch.utils.data.Dataset):
|
291 |
-
def __init__(self, tokenizer
|
292 |
super().__init__()
|
293 |
-
self.tokenizer = tokenizer
|
294 |
self.max_token_length = max_token_length
|
|
|
|
|
295 |
# width/height is used when enable_bucket==False
|
296 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
297 |
self.debug_dataset = debug_dataset
|
298 |
-
|
299 |
-
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
300 |
-
|
301 |
self.token_padding_disabled = False
|
|
|
|
|
302 |
self.tag_frequency = {}
|
303 |
|
304 |
self.enable_bucket = False
|
@@ -312,28 +225,49 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
312 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
313 |
|
314 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
315 |
|
316 |
# augmentation
|
317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
320 |
|
321 |
self.image_data: Dict[str, ImageInfo] = {}
|
322 |
-
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
323 |
|
324 |
self.replacements = {}
|
325 |
|
326 |
def set_current_epoch(self, epoch):
|
327 |
self.current_epoch = epoch
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
def set_tag_frequency(self, dir_name, captions):
|
331 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
332 |
self.tag_frequency[dir_name] = frequency_for_dir
|
333 |
for caption in captions:
|
334 |
for tag in caption.split(","):
|
335 |
-
tag
|
336 |
-
if tag:
|
337 |
tag = tag.lower()
|
338 |
frequency = frequency_for_dir.get(tag, 0)
|
339 |
frequency_for_dir[tag] = frequency + 1
|
@@ -344,36 +278,42 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
344 |
def add_replacement(self, str_from, str_to):
|
345 |
self.replacements[str_from] = str_to
|
346 |
|
347 |
-
def process_caption(self,
|
348 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
349 |
-
is_drop_out =
|
350 |
-
is_drop_out = is_drop_out or
|
351 |
|
352 |
if is_drop_out:
|
353 |
caption = ""
|
354 |
else:
|
355 |
-
if
|
356 |
def dropout_tags(tokens):
|
357 |
-
if
|
358 |
return tokens
|
359 |
l = []
|
360 |
for token in tokens:
|
361 |
-
if random.random() >=
|
362 |
l.append(token)
|
363 |
return l
|
364 |
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
-
|
372 |
-
|
373 |
|
374 |
-
|
375 |
|
376 |
-
|
|
|
377 |
|
378 |
# textual inversion対応
|
379 |
for str_from, str_to in self.replacements.items():
|
@@ -427,9 +367,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
427 |
input_ids = torch.stack(iids_list) # 3,77
|
428 |
return input_ids
|
429 |
|
430 |
-
def register_image(self, info: ImageInfo
|
431 |
self.image_data[info.image_key] = info
|
432 |
-
self.image_to_subset[info.image_key] = subset
|
433 |
|
434 |
def make_buckets(self):
|
435 |
'''
|
@@ -528,7 +467,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
528 |
img = np.array(image, np.uint8)
|
529 |
return img
|
530 |
|
531 |
-
def trim_and_resize_if_required(self,
|
532 |
image_height, image_width = image.shape[0:2]
|
533 |
|
534 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
@@ -538,27 +477,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
538 |
image_height, image_width = image.shape[0:2]
|
539 |
if image_width > reso[0]:
|
540 |
trim_size = image_width - reso[0]
|
541 |
-
p = trim_size // 2 if not
|
542 |
# print("w", trim_size, p)
|
543 |
image = image[:, p:p + reso[0]]
|
544 |
if image_height > reso[1]:
|
545 |
trim_size = image_height - reso[1]
|
546 |
-
p = trim_size // 2 if not
|
547 |
# print("h", trim_size, p)
|
548 |
image = image[p:p + reso[1]]
|
549 |
|
550 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
551 |
return image
|
552 |
|
553 |
-
def is_latent_cacheable(self):
|
554 |
-
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
555 |
-
|
556 |
def cache_latents(self, vae):
|
557 |
# TODO ここを高速化したい
|
558 |
print("caching latents.")
|
559 |
for info in tqdm(self.image_data.values()):
|
560 |
-
subset = self.image_to_subset[info.image_key]
|
561 |
-
|
562 |
if info.latents_npz is not None:
|
563 |
info.latents = self.load_latents_from_npz(info, False)
|
564 |
info.latents = torch.FloatTensor(info.latents)
|
@@ -568,13 +502,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
568 |
continue
|
569 |
|
570 |
image = self.load_image(info.absolute_path)
|
571 |
-
image = self.trim_and_resize_if_required(
|
572 |
|
573 |
img_tensor = self.image_transforms(image)
|
574 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
575 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
576 |
|
577 |
-
if
|
578 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
579 |
img_tensor = self.image_transforms(image)
|
580 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
@@ -584,11 +518,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
584 |
image = Image.open(image_path)
|
585 |
return image.size
|
586 |
|
587 |
-
def load_image_with_face_info(self,
|
588 |
img = self.load_image(image_path)
|
589 |
|
590 |
face_cx = face_cy = face_w = face_h = 0
|
591 |
-
if
|
592 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
593 |
if len(tokens) >= 5:
|
594 |
face_cx = int(tokens[-4])
|
@@ -599,7 +533,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
599 |
return img, face_cx, face_cy, face_w, face_h
|
600 |
|
601 |
# いい感じに切り出す
|
602 |
-
def crop_target(self,
|
603 |
height, width = image.shape[0:2]
|
604 |
if height == self.height and width == self.width:
|
605 |
return image
|
@@ -607,8 +541,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
607 |
# 画像サイズはsizeより大きいのでリサイズする
|
608 |
face_size = max(face_w, face_h)
|
609 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
610 |
-
min_scale = min(1.0, max(min_scale, self.size / (face_size *
|
611 |
-
max_scale = min(1.0, max(min_scale, self.size / (face_size *
|
612 |
if min_scale >= max_scale: # range指定がmin==max
|
613 |
scale = min_scale
|
614 |
else:
|
@@ -626,13 +560,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
626 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
627 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
628 |
|
629 |
-
if
|
630 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
631 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
632 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
633 |
else:
|
634 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
635 |
-
if
|
636 |
if face_size > self.size // 10 and face_size >= 40:
|
637 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
638 |
|
@@ -655,6 +589,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
655 |
return self._length
|
656 |
|
657 |
def __getitem__(self, index):
|
|
|
|
|
|
|
658 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
659 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
660 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
@@ -667,29 +604,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
667 |
|
668 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
669 |
image_info = self.image_data[image_key]
|
670 |
-
subset = self.image_to_subset[image_key]
|
671 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
672 |
|
673 |
# image/latentsを処理する
|
674 |
if image_info.latents is not None:
|
675 |
-
latents = image_info.latents if not
|
676 |
image = None
|
677 |
elif image_info.latents_npz is not None:
|
678 |
-
latents = self.load_latents_from_npz(image_info,
|
679 |
latents = torch.FloatTensor(latents)
|
680 |
image = None
|
681 |
else:
|
682 |
# 画像を読み込み、必要ならcropする
|
683 |
-
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
684 |
im_h, im_w = img.shape[0:2]
|
685 |
|
686 |
if self.enable_bucket:
|
687 |
-
img = self.trim_and_resize_if_required(
|
688 |
else:
|
689 |
if face_cx > 0: # 顔位置情報あり
|
690 |
-
img = self.crop_target(
|
691 |
elif im_h > self.height or im_w > self.width:
|
692 |
-
assert
|
693 |
if im_h > self.height:
|
694 |
p = random.randint(0, im_h - self.height)
|
695 |
img = img[p:p + self.height]
|
@@ -701,9 +637,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
701 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
702 |
|
703 |
# augmentation
|
704 |
-
|
705 |
-
|
706 |
-
img = aug(image=img)['image']
|
707 |
|
708 |
latents = None
|
709 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
@@ -711,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
711 |
images.append(image)
|
712 |
latents_list.append(latents)
|
713 |
|
714 |
-
caption = self.process_caption(
|
715 |
captions.append(caption)
|
716 |
if not self.token_padding_disabled: # this option might be omitted in future
|
717 |
input_ids_list.append(self.get_input_ids(caption))
|
@@ -742,8 +677,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
742 |
|
743 |
|
744 |
class DreamBoothDataset(BaseDataset):
|
745 |
-
def __init__(self,
|
746 |
-
super().__init__(tokenizer, max_token_length,
|
|
|
747 |
|
748 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
749 |
|
@@ -766,7 +702,7 @@ class DreamBoothDataset(BaseDataset):
|
|
766 |
self.bucket_reso_steps = None # この情報は使われない
|
767 |
self.bucket_no_upscale = False
|
768 |
|
769 |
-
def read_caption(img_path
|
770 |
# captionの候補ファイル名を作る
|
771 |
base_name = os.path.splitext(img_path)[0]
|
772 |
base_name_face_det = base_name
|
@@ -789,181 +725,153 @@ class DreamBoothDataset(BaseDataset):
|
|
789 |
break
|
790 |
return caption
|
791 |
|
792 |
-
def load_dreambooth_dir(
|
793 |
-
if not os.path.isdir(
|
794 |
-
print(f"
|
795 |
-
return [], []
|
796 |
|
797 |
-
|
798 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
|
800 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
801 |
captions = []
|
802 |
for img_path in img_paths:
|
803 |
-
cap_for_img = read_caption(img_path
|
804 |
-
if cap_for_img is None
|
805 |
-
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
806 |
-
captions.append("")
|
807 |
-
else:
|
808 |
-
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
809 |
|
810 |
-
self.set_tag_frequency(os.path.basename(
|
811 |
|
812 |
-
return img_paths, captions
|
813 |
|
814 |
-
print("prepare images.")
|
|
|
815 |
num_train_images = 0
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
if subset.num_repeats < 1:
|
820 |
-
print(
|
821 |
-
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
822 |
-
continue
|
823 |
-
|
824 |
-
if subset in self.subsets:
|
825 |
-
print(
|
826 |
-
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
827 |
-
continue
|
828 |
-
|
829 |
-
img_paths, captions = load_dreambooth_dir(subset)
|
830 |
-
if len(img_paths) < 1:
|
831 |
-
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
832 |
-
continue
|
833 |
-
|
834 |
-
if subset.is_reg:
|
835 |
-
num_reg_images += subset.num_repeats * len(img_paths)
|
836 |
-
else:
|
837 |
-
num_train_images += subset.num_repeats * len(img_paths)
|
838 |
|
839 |
for img_path, caption in zip(img_paths, captions):
|
840 |
-
info = ImageInfo(img_path,
|
841 |
-
|
842 |
-
reg_infos.append(info)
|
843 |
-
else:
|
844 |
-
self.register_image(info, subset)
|
845 |
|
846 |
-
|
847 |
-
self.subsets.append(subset)
|
848 |
|
849 |
print(f"{num_train_images} train images with repeating.")
|
850 |
self.num_train_images = num_train_images
|
851 |
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
|
857 |
-
print("no regularization images / 正則化画像が見つかりませんでした")
|
858 |
-
else:
|
859 |
-
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
860 |
-
n = 0
|
861 |
-
first_loop = True
|
862 |
-
while n < num_train_images:
|
863 |
-
for info in reg_infos:
|
864 |
-
if first_loop:
|
865 |
-
self.register_image(info, subset)
|
866 |
-
n += info.num_repeats
|
867 |
-
else:
|
868 |
-
info.num_repeats += 1
|
869 |
-
n += 1
|
870 |
-
if n >= num_train_images:
|
871 |
-
break
|
872 |
-
first_loop = False
|
873 |
-
|
874 |
-
self.num_reg_images = num_reg_images
|
875 |
-
|
876 |
-
|
877 |
-
class FineTuningDataset(BaseDataset):
|
878 |
-
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
879 |
-
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
880 |
|
881 |
-
|
|
|
|
|
|
|
882 |
|
883 |
-
|
884 |
-
|
|
|
885 |
|
886 |
-
|
887 |
-
if subset.num_repeats < 1:
|
888 |
-
print(
|
889 |
-
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
890 |
-
continue
|
891 |
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
continue
|
896 |
|
897 |
-
|
898 |
-
|
899 |
-
print(f"loading existing metadata: {subset.metadata_file}")
|
900 |
-
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
901 |
-
metadata = json.load(f)
|
902 |
else:
|
903 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
904 |
|
905 |
-
|
906 |
-
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
907 |
-
continue
|
908 |
-
|
909 |
-
tags_list = []
|
910 |
-
for image_key, img_md in metadata.items():
|
911 |
-
# path情報を作る
|
912 |
-
if os.path.exists(image_key):
|
913 |
-
abs_path = image_key
|
914 |
-
else:
|
915 |
-
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
916 |
-
if os.path.exists(npz_path):
|
917 |
-
abs_path = npz_path
|
918 |
-
else:
|
919 |
-
# わりといい加減だがいい方法が思いつかん
|
920 |
-
abs_path = glob_images(subset.image_dir, image_key)
|
921 |
-
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
922 |
-
abs_path = abs_path[0]
|
923 |
-
|
924 |
-
caption = img_md.get('caption')
|
925 |
-
tags = img_md.get('tags')
|
926 |
-
if caption is None:
|
927 |
-
caption = tags
|
928 |
-
elif tags is not None and len(tags) > 0:
|
929 |
-
caption = caption + ', ' + tags
|
930 |
-
tags_list.append(tags)
|
931 |
-
|
932 |
-
if caption is None:
|
933 |
-
caption = ""
|
934 |
|
935 |
-
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
936 |
-
image_info.image_size = img_md.get('train_resolution')
|
937 |
|
938 |
-
|
939 |
-
|
940 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
941 |
|
942 |
-
|
|
|
|
|
943 |
|
944 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
945 |
|
946 |
-
|
947 |
-
|
948 |
-
|
949 |
-
self.subsets.append(subset)
|
950 |
|
951 |
# check existence of all npz files
|
952 |
-
use_npz_latents =
|
953 |
if use_npz_latents:
|
954 |
-
flip_aug_in_subset = False
|
955 |
npz_any = False
|
956 |
npz_all = True
|
957 |
-
|
958 |
for image_info in self.image_data.values():
|
959 |
-
subset = self.image_to_subset[image_info.image_key]
|
960 |
-
|
961 |
has_npz = image_info.latents_npz is not None
|
962 |
npz_any = npz_any or has_npz
|
963 |
|
964 |
-
if
|
965 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
966 |
-
flip_aug_in_subset = True
|
967 |
npz_all = npz_all and has_npz
|
968 |
|
969 |
if npz_any and not npz_all:
|
@@ -975,7 +883,7 @@ class FineTuningDataset(BaseDataset):
|
|
975 |
elif not npz_all:
|
976 |
use_npz_latents = False
|
977 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
978 |
-
if
|
979 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
980 |
# else:
|
981 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
@@ -1021,7 +929,7 @@ class FineTuningDataset(BaseDataset):
|
|
1021 |
for image_info in self.image_data.values():
|
1022 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
1023 |
|
1024 |
-
def image_key_to_npz_file(self,
|
1025 |
base_name = os.path.splitext(image_key)[0]
|
1026 |
npz_file_norm = base_name + '.npz'
|
1027 |
|
@@ -1033,8 +941,8 @@ class FineTuningDataset(BaseDataset):
|
|
1033 |
return npz_file_norm, npz_file_flip
|
1034 |
|
1035 |
# image_key is relative path
|
1036 |
-
npz_file_norm = os.path.join(
|
1037 |
-
npz_file_flip = os.path.join(
|
1038 |
|
1039 |
if not os.path.exists(npz_file_norm):
|
1040 |
npz_file_norm = None
|
@@ -1045,60 +953,13 @@ class FineTuningDataset(BaseDataset):
|
|
1045 |
return npz_file_norm, npz_file_flip
|
1046 |
|
1047 |
|
1048 |
-
# behave as Dataset mock
|
1049 |
-
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1050 |
-
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
1051 |
-
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
1052 |
-
|
1053 |
-
super().__init__(datasets)
|
1054 |
-
|
1055 |
-
self.image_data = {}
|
1056 |
-
self.num_train_images = 0
|
1057 |
-
self.num_reg_images = 0
|
1058 |
-
|
1059 |
-
# simply concat together
|
1060 |
-
# TODO: handling image_data key duplication among dataset
|
1061 |
-
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
1062 |
-
for dataset in datasets:
|
1063 |
-
self.image_data.update(dataset.image_data)
|
1064 |
-
self.num_train_images += dataset.num_train_images
|
1065 |
-
self.num_reg_images += dataset.num_reg_images
|
1066 |
-
|
1067 |
-
def add_replacement(self, str_from, str_to):
|
1068 |
-
for dataset in self.datasets:
|
1069 |
-
dataset.add_replacement(str_from, str_to)
|
1070 |
-
|
1071 |
-
# def make_buckets(self):
|
1072 |
-
# for dataset in self.datasets:
|
1073 |
-
# dataset.make_buckets()
|
1074 |
-
|
1075 |
-
def cache_latents(self, vae):
|
1076 |
-
for i, dataset in enumerate(self.datasets):
|
1077 |
-
print(f"[Dataset {i}]")
|
1078 |
-
dataset.cache_latents(vae)
|
1079 |
-
|
1080 |
-
def is_latent_cacheable(self) -> bool:
|
1081 |
-
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
1082 |
-
|
1083 |
-
def set_current_epoch(self, epoch):
|
1084 |
-
for dataset in self.datasets:
|
1085 |
-
dataset.set_current_epoch(epoch)
|
1086 |
-
|
1087 |
-
def disable_token_padding(self):
|
1088 |
-
for dataset in self.datasets:
|
1089 |
-
dataset.disable_token_padding()
|
1090 |
-
|
1091 |
-
|
1092 |
def debug_dataset(train_dataset, show_input_ids=False):
|
1093 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
1094 |
print("Escape for exit. / Escキーで中断、終了します")
|
1095 |
|
1096 |
train_dataset.set_current_epoch(1)
|
1097 |
k = 0
|
1098 |
-
|
1099 |
-
random.shuffle(indices)
|
1100 |
-
for i, idx in enumerate(indices):
|
1101 |
-
example = train_dataset[idx]
|
1102 |
if example['latents'] is not None:
|
1103 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
1104 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
@@ -1503,35 +1364,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|
1503 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1504 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1505 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1506 |
-
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
1507 |
-
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
1508 |
-
|
1509 |
-
|
1510 |
-
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
1511 |
-
parser.add_argument("--optimizer_type", type=str, default="",
|
1512 |
-
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
|
1513 |
-
|
1514 |
-
# backward compatibility
|
1515 |
-
parser.add_argument("--use_8bit_adam", action="store_true",
|
1516 |
-
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1517 |
-
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1518 |
-
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1519 |
-
|
1520 |
-
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1521 |
-
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
1522 |
-
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
|
1523 |
-
|
1524 |
-
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
|
1525 |
-
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
1526 |
-
|
1527 |
-
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1528 |
-
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
|
1529 |
-
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1530 |
-
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1531 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
1532 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
1533 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
1534 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
1535 |
|
1536 |
|
1537 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
@@ -1555,6 +1387,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1555 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1556 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1557 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
|
|
|
|
|
|
|
|
1558 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1559 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1560 |
parser.add_argument("--xformers", action="store_true",
|
@@ -1562,6 +1398,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1562 |
parser.add_argument("--vae", type=str, default=None,
|
1563 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1564 |
|
|
|
1565 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1566 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1567 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
@@ -1582,23 +1419,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1582 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1583 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1584 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
|
|
|
|
|
|
|
|
1585 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1586 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1587 |
parser.add_argument("--lowram", action="store_true",
|
1588 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1589 |
|
1590 |
-
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
1591 |
-
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
1592 |
-
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
1593 |
-
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
1594 |
-
parser.add_argument("--sample_prompts", type=str, default=None,
|
1595 |
-
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
1596 |
-
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
1597 |
-
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
1598 |
-
'dpmsolver++', 'dpmsingle',
|
1599 |
-
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
1600 |
-
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
1601 |
-
|
1602 |
if support_dreambooth:
|
1603 |
# DreamBooth training
|
1604 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
@@ -1620,8 +1449,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1620 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1621 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1622 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1623 |
-
parser.add_argument("--keep_tokens", type=int, default=
|
1624 |
-
help="keep heading N tokens when shuffling caption tokens
|
1625 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1626 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1627 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
@@ -1646,11 +1475,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1646 |
if support_caption_dropout:
|
1647 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1648 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1649 |
-
parser.add_argument("--caption_dropout_rate", type=float, default=0
|
1650 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1651 |
-
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
|
1652 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1653 |
-
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0
|
1654 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1655 |
|
1656 |
if support_dreambooth:
|
@@ -1675,256 +1504,16 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|
1675 |
# region utils
|
1676 |
|
1677 |
|
1678 |
-
def get_optimizer(args, trainable_params):
|
1679 |
-
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
1680 |
-
|
1681 |
-
optimizer_type = args.optimizer_type
|
1682 |
-
if args.use_8bit_adam:
|
1683 |
-
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
|
1684 |
-
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
|
1685 |
-
optimizer_type = "AdamW8bit"
|
1686 |
-
|
1687 |
-
elif args.use_lion_optimizer:
|
1688 |
-
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
|
1689 |
-
optimizer_type = "Lion"
|
1690 |
-
|
1691 |
-
if optimizer_type is None or optimizer_type == "":
|
1692 |
-
optimizer_type = "AdamW"
|
1693 |
-
optimizer_type = optimizer_type.lower()
|
1694 |
-
|
1695 |
-
# 引数を分解する:boolとfloat、tupleのみ対応
|
1696 |
-
optimizer_kwargs = {}
|
1697 |
-
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
1698 |
-
for arg in args.optimizer_args:
|
1699 |
-
key, value = arg.split('=')
|
1700 |
-
|
1701 |
-
value = value.split(",")
|
1702 |
-
for i in range(len(value)):
|
1703 |
-
if value[i].lower() == "true" or value[i].lower() == "false":
|
1704 |
-
value[i] = (value[i].lower() == "true")
|
1705 |
-
else:
|
1706 |
-
value[i] = float(value[i])
|
1707 |
-
if len(value) == 1:
|
1708 |
-
value = value[0]
|
1709 |
-
else:
|
1710 |
-
value = tuple(value)
|
1711 |
-
|
1712 |
-
optimizer_kwargs[key] = value
|
1713 |
-
# print("optkwargs:", optimizer_kwargs)
|
1714 |
-
|
1715 |
-
lr = args.learning_rate
|
1716 |
-
|
1717 |
-
if optimizer_type == "AdamW8bit".lower():
|
1718 |
-
try:
|
1719 |
-
import bitsandbytes as bnb
|
1720 |
-
except ImportError:
|
1721 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1722 |
-
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
1723 |
-
optimizer_class = bnb.optim.AdamW8bit
|
1724 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1725 |
-
|
1726 |
-
elif optimizer_type == "SGDNesterov8bit".lower():
|
1727 |
-
try:
|
1728 |
-
import bitsandbytes as bnb
|
1729 |
-
except ImportError:
|
1730 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1731 |
-
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1732 |
-
if "momentum" not in optimizer_kwargs:
|
1733 |
-
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1734 |
-
optimizer_kwargs["momentum"] = 0.9
|
1735 |
-
|
1736 |
-
optimizer_class = bnb.optim.SGD8bit
|
1737 |
-
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1738 |
-
|
1739 |
-
elif optimizer_type == "Lion".lower():
|
1740 |
-
try:
|
1741 |
-
import lion_pytorch
|
1742 |
-
except ImportError:
|
1743 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
1744 |
-
print(f"use Lion optimizer | {optimizer_kwargs}")
|
1745 |
-
optimizer_class = lion_pytorch.Lion
|
1746 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1747 |
-
|
1748 |
-
elif optimizer_type == "SGDNesterov".lower():
|
1749 |
-
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1750 |
-
if "momentum" not in optimizer_kwargs:
|
1751 |
-
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1752 |
-
optimizer_kwargs["momentum"] = 0.9
|
1753 |
-
|
1754 |
-
optimizer_class = torch.optim.SGD
|
1755 |
-
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1756 |
-
|
1757 |
-
elif optimizer_type == "DAdaptation".lower():
|
1758 |
-
try:
|
1759 |
-
import dadaptation
|
1760 |
-
except ImportError:
|
1761 |
-
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
1762 |
-
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
1763 |
-
|
1764 |
-
actual_lr = lr
|
1765 |
-
lr_count = 1
|
1766 |
-
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1767 |
-
lrs = set()
|
1768 |
-
actual_lr = trainable_params[0].get("lr", actual_lr)
|
1769 |
-
for group in trainable_params:
|
1770 |
-
lrs.add(group.get("lr", actual_lr))
|
1771 |
-
lr_count = len(lrs)
|
1772 |
-
|
1773 |
-
if actual_lr <= 0.1:
|
1774 |
-
print(
|
1775 |
-
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
|
1776 |
-
print('recommend option: lr=1.0 / 推奨は1.0です')
|
1777 |
-
if lr_count > 1:
|
1778 |
-
print(
|
1779 |
-
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
|
1780 |
-
|
1781 |
-
optimizer_class = dadaptation.DAdaptAdam
|
1782 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1783 |
-
|
1784 |
-
elif optimizer_type == "Adafactor".lower():
|
1785 |
-
# 引数を確認して適宜補正する
|
1786 |
-
if "relative_step" not in optimizer_kwargs:
|
1787 |
-
optimizer_kwargs["relative_step"] = True # default
|
1788 |
-
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
1789 |
-
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
|
1790 |
-
optimizer_kwargs["relative_step"] = True
|
1791 |
-
print(f"use Adafactor optimizer | {optimizer_kwargs}")
|
1792 |
-
|
1793 |
-
if optimizer_kwargs["relative_step"]:
|
1794 |
-
print(f"relative_step is true / relative_stepがtrueです")
|
1795 |
-
if lr != 0.0:
|
1796 |
-
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
1797 |
-
args.learning_rate = None
|
1798 |
-
|
1799 |
-
# trainable_paramsがgroupだった時の処理:lrを削除する
|
1800 |
-
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1801 |
-
has_group_lr = False
|
1802 |
-
for group in trainable_params:
|
1803 |
-
p = group.pop("lr", None)
|
1804 |
-
has_group_lr = has_group_lr or (p is not None)
|
1805 |
-
|
1806 |
-
if has_group_lr:
|
1807 |
-
# 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
|
1808 |
-
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
|
1809 |
-
args.unet_lr = None
|
1810 |
-
args.text_encoder_lr = None
|
1811 |
-
|
1812 |
-
if args.lr_scheduler != "adafactor":
|
1813 |
-
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
1814 |
-
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
1815 |
-
|
1816 |
-
lr = None
|
1817 |
-
else:
|
1818 |
-
if args.max_grad_norm != 0.0:
|
1819 |
-
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
|
1820 |
-
if args.lr_scheduler != "constant_with_warmup":
|
1821 |
-
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
1822 |
-
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
1823 |
-
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
1824 |
-
|
1825 |
-
optimizer_class = transformers.optimization.Adafactor
|
1826 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1827 |
-
|
1828 |
-
elif optimizer_type == "AdamW".lower():
|
1829 |
-
print(f"use AdamW optimizer | {optimizer_kwargs}")
|
1830 |
-
optimizer_class = torch.optim.AdamW
|
1831 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1832 |
-
|
1833 |
-
else:
|
1834 |
-
# 任意のoptimizerを使う
|
1835 |
-
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
1836 |
-
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
1837 |
-
if "." not in optimizer_type:
|
1838 |
-
optimizer_module = torch.optim
|
1839 |
-
else:
|
1840 |
-
values = optimizer_type.split(".")
|
1841 |
-
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
1842 |
-
optimizer_type = values[-1]
|
1843 |
-
|
1844 |
-
optimizer_class = getattr(optimizer_module, optimizer_type)
|
1845 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1846 |
-
|
1847 |
-
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
1848 |
-
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
1849 |
-
|
1850 |
-
return optimizer_name, optimizer_args, optimizer
|
1851 |
-
|
1852 |
-
|
1853 |
-
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
1854 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
1855 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
1856 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
1857 |
-
|
1858 |
-
|
1859 |
-
def get_scheduler_fix(
|
1860 |
-
name: Union[str, SchedulerType],
|
1861 |
-
optimizer: Optimizer,
|
1862 |
-
num_warmup_steps: Optional[int] = None,
|
1863 |
-
num_training_steps: Optional[int] = None,
|
1864 |
-
num_cycles: int = 1,
|
1865 |
-
power: float = 1.0,
|
1866 |
-
):
|
1867 |
-
"""
|
1868 |
-
Unified API to get any scheduler from its name.
|
1869 |
-
Args:
|
1870 |
-
name (`str` or `SchedulerType`):
|
1871 |
-
The name of the scheduler to use.
|
1872 |
-
optimizer (`torch.optim.Optimizer`):
|
1873 |
-
The optimizer that will be used during training.
|
1874 |
-
num_warmup_steps (`int`, *optional*):
|
1875 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
1876 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1877 |
-
num_training_steps (`int``, *optional*):
|
1878 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
1879 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1880 |
-
num_cycles (`int`, *optional*):
|
1881 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
1882 |
-
power (`float`, *optional*, defaults to 1.0):
|
1883 |
-
Power factor. See `POLYNOMIAL` scheduler
|
1884 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
1885 |
-
The index of the last epoch when resuming training.
|
1886 |
-
"""
|
1887 |
-
if name.startswith("adafactor"):
|
1888 |
-
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
1889 |
-
initial_lr = float(name.split(':')[1])
|
1890 |
-
# print("adafactor scheduler init lr", initial_lr)
|
1891 |
-
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
1892 |
-
|
1893 |
-
name = SchedulerType(name)
|
1894 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
1895 |
-
if name == SchedulerType.CONSTANT:
|
1896 |
-
return schedule_func(optimizer)
|
1897 |
-
|
1898 |
-
# All other schedulers require `num_warmup_steps`
|
1899 |
-
if num_warmup_steps is None:
|
1900 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
1901 |
-
|
1902 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
1903 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
1904 |
-
|
1905 |
-
# All other schedulers require `num_training_steps`
|
1906 |
-
if num_training_steps is None:
|
1907 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
1908 |
-
|
1909 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
1910 |
-
return schedule_func(
|
1911 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
1912 |
-
)
|
1913 |
-
|
1914 |
-
if name == SchedulerType.POLYNOMIAL:
|
1915 |
-
return schedule_func(
|
1916 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
1917 |
-
)
|
1918 |
-
|
1919 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
1920 |
-
|
1921 |
-
|
1922 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1923 |
# backward compatibility
|
1924 |
if args.caption_extention is not None:
|
1925 |
args.caption_extension = args.caption_extention
|
1926 |
args.caption_extention = None
|
1927 |
|
|
|
|
|
|
|
|
|
1928 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1929 |
if args.resolution is not None:
|
1930 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
@@ -1947,28 +1536,12 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|
1947 |
|
1948 |
def load_tokenizer(args: argparse.Namespace):
|
1949 |
print("prepare tokenizer")
|
1950 |
-
|
1951 |
-
|
1952 |
-
|
1953 |
-
|
1954 |
-
|
1955 |
-
if os.path.exists(local_tokenizer_path):
|
1956 |
-
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
1957 |
-
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
1958 |
-
|
1959 |
-
if tokenizer is None:
|
1960 |
-
if args.v2:
|
1961 |
-
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
1962 |
-
else:
|
1963 |
-
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
1964 |
-
|
1965 |
-
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
1966 |
print(f"update token length: {args.max_token_length}")
|
1967 |
-
|
1968 |
-
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
1969 |
-
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
1970 |
-
tokenizer.save_pretrained(local_tokenizer_path)
|
1971 |
-
|
1972 |
return tokenizer
|
1973 |
|
1974 |
|
@@ -2019,19 +1592,13 @@ def prepare_dtype(args: argparse.Namespace):
|
|
2019 |
|
2020 |
|
2021 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
2022 |
-
|
2023 |
-
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
2024 |
-
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
2025 |
if load_stable_diffusion_format:
|
2026 |
print("load StableDiffusion checkpoint")
|
2027 |
-
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
|
2028 |
else:
|
2029 |
print("load Diffusers pretrained models")
|
2030 |
-
|
2031 |
-
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
2032 |
-
except EnvironmentError as ex:
|
2033 |
-
print(
|
2034 |
-
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
2035 |
text_encoder = pipe.text_encoder
|
2036 |
vae = pipe.vae
|
2037 |
unet = pipe.unet
|
@@ -2200,197 +1767,6 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
2200 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
2201 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
2202 |
|
2203 |
-
|
2204 |
-
# scheduler:
|
2205 |
-
SCHEDULER_LINEAR_START = 0.00085
|
2206 |
-
SCHEDULER_LINEAR_END = 0.0120
|
2207 |
-
SCHEDULER_TIMESTEPS = 1000
|
2208 |
-
SCHEDLER_SCHEDULE = 'scaled_linear'
|
2209 |
-
|
2210 |
-
|
2211 |
-
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
2212 |
-
"""
|
2213 |
-
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
2214 |
-
clip skipは対応した
|
2215 |
-
"""
|
2216 |
-
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
2217 |
-
return
|
2218 |
-
if args.sample_every_n_epochs is not None:
|
2219 |
-
# sample_every_n_steps は無視する
|
2220 |
-
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
2221 |
-
return
|
2222 |
-
else:
|
2223 |
-
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
2224 |
-
return
|
2225 |
-
|
2226 |
-
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
2227 |
-
if not os.path.isfile(args.sample_prompts):
|
2228 |
-
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
2229 |
-
return
|
2230 |
-
|
2231 |
-
org_vae_device = vae.device # CPUにいるはず
|
2232 |
-
vae.to(device)
|
2233 |
-
|
2234 |
-
# clip skip 対応のための wrapper を作る
|
2235 |
-
if args.clip_skip is None:
|
2236 |
-
text_encoder_or_wrapper = text_encoder
|
2237 |
-
else:
|
2238 |
-
class Wrapper():
|
2239 |
-
def __init__(self, tenc) -> None:
|
2240 |
-
self.tenc = tenc
|
2241 |
-
self.config = {}
|
2242 |
-
super().__init__()
|
2243 |
-
|
2244 |
-
def __call__(self, input_ids, attention_mask):
|
2245 |
-
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
2246 |
-
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
2247 |
-
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
2248 |
-
pooled_output = enc_out['pooler_output']
|
2249 |
-
return encoder_hidden_states, pooled_output # 1st output is only used
|
2250 |
-
|
2251 |
-
text_encoder_or_wrapper = Wrapper(text_encoder)
|
2252 |
-
|
2253 |
-
# read prompts
|
2254 |
-
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
|
2255 |
-
prompts = f.readlines()
|
2256 |
-
|
2257 |
-
# schedulerを用意する
|
2258 |
-
sched_init_args = {}
|
2259 |
-
if args.sample_sampler == "ddim":
|
2260 |
-
scheduler_cls = DDIMScheduler
|
2261 |
-
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
2262 |
-
scheduler_cls = DDPMScheduler
|
2263 |
-
elif args.sample_sampler == "pndm":
|
2264 |
-
scheduler_cls = PNDMScheduler
|
2265 |
-
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
|
2266 |
-
scheduler_cls = LMSDiscreteScheduler
|
2267 |
-
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
|
2268 |
-
scheduler_cls = EulerDiscreteScheduler
|
2269 |
-
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
|
2270 |
-
scheduler_cls = EulerAncestralDiscreteScheduler
|
2271 |
-
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
2272 |
-
scheduler_cls = DPMSolverMultistepScheduler
|
2273 |
-
sched_init_args['algorithm_type'] = args.sample_sampler
|
2274 |
-
elif args.sample_sampler == "dpmsingle":
|
2275 |
-
scheduler_cls = DPMSolverSinglestepScheduler
|
2276 |
-
elif args.sample_sampler == "heun":
|
2277 |
-
scheduler_cls = HeunDiscreteScheduler
|
2278 |
-
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
|
2279 |
-
scheduler_cls = KDPM2DiscreteScheduler
|
2280 |
-
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
|
2281 |
-
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
2282 |
-
else:
|
2283 |
-
scheduler_cls = DDIMScheduler
|
2284 |
-
|
2285 |
-
if args.v_parameterization:
|
2286 |
-
sched_init_args['prediction_type'] = 'v_prediction'
|
2287 |
-
|
2288 |
-
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
|
2289 |
-
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
|
2290 |
-
beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
|
2291 |
-
|
2292 |
-
# clip_sample=Trueにする
|
2293 |
-
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
2294 |
-
# print("set clip_sample to True")
|
2295 |
-
scheduler.config.clip_sample = True
|
2296 |
-
|
2297 |
-
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
2298 |
-
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
2299 |
-
pipeline.to(device)
|
2300 |
-
|
2301 |
-
save_dir = args.output_dir + "/sample"
|
2302 |
-
os.makedirs(save_dir, exist_ok=True)
|
2303 |
-
|
2304 |
-
rng_state = torch.get_rng_state()
|
2305 |
-
cuda_rng_state = torch.cuda.get_rng_state()
|
2306 |
-
|
2307 |
-
with torch.no_grad():
|
2308 |
-
with accelerator.autocast():
|
2309 |
-
for i, prompt in enumerate(prompts):
|
2310 |
-
if not accelerator.is_main_process:
|
2311 |
-
continue
|
2312 |
-
prompt = prompt.strip()
|
2313 |
-
if len(prompt) == 0 or prompt[0] == '#':
|
2314 |
-
continue
|
2315 |
-
|
2316 |
-
# subset of gen_img_diffusers
|
2317 |
-
prompt_args = prompt.split(' --')
|
2318 |
-
prompt = prompt_args[0]
|
2319 |
-
negative_prompt = None
|
2320 |
-
sample_steps = 30
|
2321 |
-
width = height = 512
|
2322 |
-
scale = 7.5
|
2323 |
-
seed = None
|
2324 |
-
for parg in prompt_args:
|
2325 |
-
try:
|
2326 |
-
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
|
2327 |
-
if m:
|
2328 |
-
width = int(m.group(1))
|
2329 |
-
continue
|
2330 |
-
|
2331 |
-
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
|
2332 |
-
if m:
|
2333 |
-
height = int(m.group(1))
|
2334 |
-
continue
|
2335 |
-
|
2336 |
-
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
|
2337 |
-
if m:
|
2338 |
-
seed = int(m.group(1))
|
2339 |
-
continue
|
2340 |
-
|
2341 |
-
m = re.match(r's (\d+)', parg, re.IGNORECASE)
|
2342 |
-
if m: # steps
|
2343 |
-
sample_steps = max(1, min(1000, int(m.group(1))))
|
2344 |
-
continue
|
2345 |
-
|
2346 |
-
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
|
2347 |
-
if m: # scale
|
2348 |
-
scale = float(m.group(1))
|
2349 |
-
continue
|
2350 |
-
|
2351 |
-
m = re.match(r'n (.+)', parg, re.IGNORECASE)
|
2352 |
-
if m: # negative prompt
|
2353 |
-
negative_prompt = m.group(1)
|
2354 |
-
continue
|
2355 |
-
|
2356 |
-
except ValueError as ex:
|
2357 |
-
print(f"Exception in parsing / 解析エラー: {parg}")
|
2358 |
-
print(ex)
|
2359 |
-
|
2360 |
-
if seed is not None:
|
2361 |
-
torch.manual_seed(seed)
|
2362 |
-
torch.cuda.manual_seed(seed)
|
2363 |
-
|
2364 |
-
if prompt_replacement is not None:
|
2365 |
-
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2366 |
-
if negative_prompt is not None:
|
2367 |
-
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2368 |
-
|
2369 |
-
height = max(64, height - height % 8) # round to divisible by 8
|
2370 |
-
width = max(64, width - width % 8) # round to divisible by 8
|
2371 |
-
print(f"prompt: {prompt}")
|
2372 |
-
print(f"negative_prompt: {negative_prompt}")
|
2373 |
-
print(f"height: {height}")
|
2374 |
-
print(f"width: {width}")
|
2375 |
-
print(f"sample_steps: {sample_steps}")
|
2376 |
-
print(f"scale: {scale}")
|
2377 |
-
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
2378 |
-
|
2379 |
-
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
2380 |
-
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
2381 |
-
seed_suffix = "" if seed is None else f"_{seed}"
|
2382 |
-
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
2383 |
-
|
2384 |
-
image.save(os.path.join(save_dir, img_filename))
|
2385 |
-
|
2386 |
-
# clear pipeline and cache to reduce vram usage
|
2387 |
-
del pipeline
|
2388 |
-
torch.cuda.empty_cache()
|
2389 |
-
|
2390 |
-
torch.set_rng_state(rng_state)
|
2391 |
-
torch.cuda.set_rng_state(cuda_rng_state)
|
2392 |
-
vae.to(org_vae_device)
|
2393 |
-
|
2394 |
# endregion
|
2395 |
|
2396 |
# region 前処理用
|
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
|
|
4 |
import json
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
import glob
|
11 |
import math
|
12 |
import os
|
|
|
17 |
|
18 |
from tqdm import tqdm
|
19 |
import torch
|
|
|
20 |
from torchvision import transforms
|
21 |
from transformers import CLIPTokenizer
|
|
|
22 |
import diffusers
|
23 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
|
|
|
|
|
|
|
|
24 |
import albumentations as albu
|
25 |
import numpy as np
|
26 |
from PIL import Image
|
|
|
195 |
batch_index: int
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
class BaseDataset(torch.utils.data.Dataset):
|
199 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
200 |
super().__init__()
|
201 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
202 |
self.max_token_length = max_token_length
|
203 |
+
self.shuffle_caption = shuffle_caption
|
204 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
# width/height is used when enable_bucket==False
|
206 |
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
+
self.face_crop_aug_range = face_crop_aug_range
|
208 |
+
self.flip_aug = flip_aug
|
209 |
+
self.color_aug = color_aug
|
210 |
self.debug_dataset = debug_dataset
|
211 |
+
self.random_crop = random_crop
|
|
|
|
|
212 |
self.token_padding_disabled = False
|
213 |
+
self.dataset_dirs_info = {}
|
214 |
+
self.reg_dataset_dirs_info = {}
|
215 |
self.tag_frequency = {}
|
216 |
|
217 |
self.enable_bucket = False
|
|
|
225 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
|
227 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
+
self.dropout_rate: float = 0
|
229 |
+
self.dropout_every_n_epochs: int = None
|
230 |
+
self.tag_dropout_rate: float = 0
|
231 |
|
232 |
# augmentation
|
233 |
+
flip_p = 0.5 if flip_aug else 0.0
|
234 |
+
if color_aug:
|
235 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
+
self.aug = albu.Compose([
|
237 |
+
albu.OneOf([
|
238 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
+
albu.RandomGamma((95, 105), p=.5),
|
240 |
+
], p=.33),
|
241 |
+
albu.HorizontalFlip(p=flip_p)
|
242 |
+
], p=1.)
|
243 |
+
elif flip_aug:
|
244 |
+
self.aug = albu.Compose([
|
245 |
+
albu.HorizontalFlip(p=flip_p)
|
246 |
+
], p=1.)
|
247 |
+
else:
|
248 |
+
self.aug = None
|
249 |
|
250 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
|
252 |
self.image_data: Dict[str, ImageInfo] = {}
|
|
|
253 |
|
254 |
self.replacements = {}
|
255 |
|
256 |
def set_current_epoch(self, epoch):
|
257 |
self.current_epoch = epoch
|
258 |
+
|
259 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
+
self.dropout_rate = dropout_rate
|
262 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
+
self.tag_dropout_rate = tag_dropout_rate
|
264 |
|
265 |
def set_tag_frequency(self, dir_name, captions):
|
266 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
for caption in captions:
|
269 |
for tag in caption.split(","):
|
270 |
+
if tag and not tag.isspace():
|
|
|
271 |
tag = tag.lower()
|
272 |
frequency = frequency_for_dir.get(tag, 0)
|
273 |
frequency_for_dir[tag] = frequency + 1
|
|
|
278 |
def add_replacement(self, str_from, str_to):
|
279 |
self.replacements[str_from] = str_to
|
280 |
|
281 |
+
def process_caption(self, caption):
|
282 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
284 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
285 |
|
286 |
if is_drop_out:
|
287 |
caption = ""
|
288 |
else:
|
289 |
+
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
290 |
def dropout_tags(tokens):
|
291 |
+
if self.tag_dropout_rate <= 0:
|
292 |
return tokens
|
293 |
l = []
|
294 |
for token in tokens:
|
295 |
+
if random.random() >= self.tag_dropout_rate:
|
296 |
l.append(token)
|
297 |
return l
|
298 |
|
299 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
300 |
+
if self.shuffle_keep_tokens is None:
|
301 |
+
if self.shuffle_caption:
|
302 |
+
random.shuffle(tokens)
|
303 |
+
|
304 |
+
tokens = dropout_tags(tokens)
|
305 |
+
else:
|
306 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
|
310 |
+
if self.shuffle_caption:
|
311 |
+
random.shuffle(tokens)
|
312 |
|
313 |
+
tokens = dropout_tags(tokens)
|
314 |
|
315 |
+
tokens = keep_tokens + tokens
|
316 |
+
caption = ", ".join(tokens)
|
317 |
|
318 |
# textual inversion対応
|
319 |
for str_from, str_to in self.replacements.items():
|
|
|
367 |
input_ids = torch.stack(iids_list) # 3,77
|
368 |
return input_ids
|
369 |
|
370 |
+
def register_image(self, info: ImageInfo):
|
371 |
self.image_data[info.image_key] = info
|
|
|
372 |
|
373 |
def make_buckets(self):
|
374 |
'''
|
|
|
467 |
img = np.array(image, np.uint8)
|
468 |
return img
|
469 |
|
470 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
image_height, image_width = image.shape[0:2]
|
472 |
|
473 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
477 |
image_height, image_width = image.shape[0:2]
|
478 |
if image_width > reso[0]:
|
479 |
trim_size = image_width - reso[0]
|
480 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
481 |
# print("w", trim_size, p)
|
482 |
image = image[:, p:p + reso[0]]
|
483 |
if image_height > reso[1]:
|
484 |
trim_size = image_height - reso[1]
|
485 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
486 |
# print("h", trim_size, p)
|
487 |
image = image[p:p + reso[1]]
|
488 |
|
489 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
return image
|
491 |
|
|
|
|
|
|
|
492 |
def cache_latents(self, vae):
|
493 |
# TODO ここを高速化したい
|
494 |
print("caching latents.")
|
495 |
for info in tqdm(self.image_data.values()):
|
|
|
|
|
496 |
if info.latents_npz is not None:
|
497 |
info.latents = self.load_latents_from_npz(info, False)
|
498 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
502 |
continue
|
503 |
|
504 |
image = self.load_image(info.absolute_path)
|
505 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
|
507 |
img_tensor = self.image_transforms(image)
|
508 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
|
511 |
+
if self.flip_aug:
|
512 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
img_tensor = self.image_transforms(image)
|
514 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
518 |
image = Image.open(image_path)
|
519 |
return image.size
|
520 |
|
521 |
+
def load_image_with_face_info(self, image_path: str):
|
522 |
img = self.load_image(image_path)
|
523 |
|
524 |
face_cx = face_cy = face_w = face_h = 0
|
525 |
+
if self.face_crop_aug_range is not None:
|
526 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
if len(tokens) >= 5:
|
528 |
face_cx = int(tokens[-4])
|
|
|
533 |
return img, face_cx, face_cy, face_w, face_h
|
534 |
|
535 |
# いい感じに切り出す
|
536 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
height, width = image.shape[0:2]
|
538 |
if height == self.height and width == self.width:
|
539 |
return image
|
|
|
541 |
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
face_size = max(face_w, face_h)
|
543 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
545 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
546 |
if min_scale >= max_scale: # range指定がmin==max
|
547 |
scale = min_scale
|
548 |
else:
|
|
|
560 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
|
563 |
+
if self.random_crop:
|
564 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
else:
|
568 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
570 |
if face_size > self.size // 10 and face_size >= 40:
|
571 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
|
|
|
589 |
return self._length
|
590 |
|
591 |
def __getitem__(self, index):
|
592 |
+
if index == 0:
|
593 |
+
self.shuffle_buckets()
|
594 |
+
|
595 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
604 |
|
605 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
image_info = self.image_data[image_key]
|
|
|
607 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
|
609 |
# image/latentsを処理する
|
610 |
if image_info.latents is not None:
|
611 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
612 |
image = None
|
613 |
elif image_info.latents_npz is not None:
|
614 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
615 |
latents = torch.FloatTensor(latents)
|
616 |
image = None
|
617 |
else:
|
618 |
# 画像を読み込み、必要ならcropする
|
619 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
im_h, im_w = img.shape[0:2]
|
621 |
|
622 |
if self.enable_bucket:
|
623 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
else:
|
625 |
if face_cx > 0: # 顔位置情報あり
|
626 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
elif im_h > self.height or im_w > self.width:
|
628 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
629 |
if im_h > self.height:
|
630 |
p = random.randint(0, im_h - self.height)
|
631 |
img = img[p:p + self.height]
|
|
|
637 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
|
639 |
# augmentation
|
640 |
+
if self.aug is not None:
|
641 |
+
img = self.aug(image=img)['image']
|
|
|
642 |
|
643 |
latents = None
|
644 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
646 |
images.append(image)
|
647 |
latents_list.append(latents)
|
648 |
|
649 |
+
caption = self.process_caption(image_info.caption)
|
650 |
captions.append(caption)
|
651 |
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
677 |
|
678 |
|
679 |
class DreamBoothDataset(BaseDataset):
|
680 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
681 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
682 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
|
684 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
|
|
|
702 |
self.bucket_reso_steps = None # この情報は使われない
|
703 |
self.bucket_no_upscale = False
|
704 |
|
705 |
+
def read_caption(img_path):
|
706 |
# captionの候補ファイル名を作る
|
707 |
base_name = os.path.splitext(img_path)[0]
|
708 |
base_name_face_det = base_name
|
|
|
725 |
break
|
726 |
return caption
|
727 |
|
728 |
+
def load_dreambooth_dir(dir):
|
729 |
+
if not os.path.isdir(dir):
|
730 |
+
# print(f"ignore file: {dir}")
|
731 |
+
return 0, [], []
|
732 |
|
733 |
+
tokens = os.path.basename(dir).split('_')
|
734 |
+
try:
|
735 |
+
n_repeats = int(tokens[0])
|
736 |
+
except ValueError as e:
|
737 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
+
return 0, [], []
|
739 |
+
|
740 |
+
caption_by_folder = '_'.join(tokens[1:])
|
741 |
+
img_paths = glob_images(dir, "*")
|
742 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
|
744 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
captions = []
|
746 |
for img_path in img_paths:
|
747 |
+
cap_for_img = read_caption(img_path)
|
748 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
|
|
|
|
|
|
|
|
749 |
|
750 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
751 |
|
752 |
+
return n_repeats, img_paths, captions
|
753 |
|
754 |
+
print("prepare train images.")
|
755 |
+
train_dirs = os.listdir(train_data_dir)
|
756 |
num_train_images = 0
|
757 |
+
for dir in train_dirs:
|
758 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
759 |
+
num_train_images += n_repeats * len(img_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
|
761 |
for img_path, caption in zip(img_paths, captions):
|
762 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
763 |
+
self.register_image(info)
|
|
|
|
|
|
|
764 |
|
765 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
|
|
766 |
|
767 |
print(f"{num_train_images} train images with repeating.")
|
768 |
self.num_train_images = num_train_images
|
769 |
|
770 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
771 |
+
num_reg_images = 0
|
772 |
+
if reg_data_dir:
|
773 |
+
print("prepare reg images.")
|
774 |
+
reg_infos: List[ImageInfo] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
775 |
|
776 |
+
reg_dirs = os.listdir(reg_data_dir)
|
777 |
+
for dir in reg_dirs:
|
778 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
779 |
+
num_reg_images += n_repeats * len(img_paths)
|
780 |
|
781 |
+
for img_path, caption in zip(img_paths, captions):
|
782 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
+
reg_infos.append(info)
|
784 |
|
785 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
|
|
|
|
|
|
|
|
786 |
|
787 |
+
print(f"{num_reg_images} reg images.")
|
788 |
+
if num_train_images < num_reg_images:
|
789 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
|
|
790 |
|
791 |
+
if num_reg_images == 0:
|
792 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
|
|
|
|
|
|
793 |
else:
|
794 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
795 |
+
n = 0
|
796 |
+
first_loop = True
|
797 |
+
while n < num_train_images:
|
798 |
+
for info in reg_infos:
|
799 |
+
if first_loop:
|
800 |
+
self.register_image(info)
|
801 |
+
n += info.num_repeats
|
802 |
+
else:
|
803 |
+
info.num_repeats += 1
|
804 |
+
n += 1
|
805 |
+
if n >= num_train_images:
|
806 |
+
break
|
807 |
+
first_loop = False
|
808 |
|
809 |
+
self.num_reg_images = num_reg_images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
810 |
|
|
|
|
|
811 |
|
812 |
+
class FineTuningDataset(BaseDataset):
|
813 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
814 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
815 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
816 |
+
|
817 |
+
# メタデータを読み込む
|
818 |
+
if os.path.exists(json_file_name):
|
819 |
+
print(f"loading existing metadata: {json_file_name}")
|
820 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
+
metadata = json.load(f)
|
822 |
+
else:
|
823 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
|
825 |
+
self.metadata = metadata
|
826 |
+
self.train_data_dir = train_data_dir
|
827 |
+
self.batch_size = batch_size
|
828 |
|
829 |
+
tags_list = []
|
830 |
+
for image_key, img_md in metadata.items():
|
831 |
+
# path情報を作る
|
832 |
+
if os.path.exists(image_key):
|
833 |
+
abs_path = image_key
|
834 |
+
else:
|
835 |
+
# わりといい加減だがいい方法が思いつかん
|
836 |
+
abs_path = glob_images(train_data_dir, image_key)
|
837 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
+
abs_path = abs_path[0]
|
839 |
+
|
840 |
+
caption = img_md.get('caption')
|
841 |
+
tags = img_md.get('tags')
|
842 |
+
if caption is None:
|
843 |
+
caption = tags
|
844 |
+
elif tags is not None and len(tags) > 0:
|
845 |
+
caption = caption + ', ' + tags
|
846 |
+
tags_list.append(tags)
|
847 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
+
|
849 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
+
image_info.image_size = img_md.get('train_resolution')
|
851 |
+
|
852 |
+
if not self.color_aug and not self.random_crop:
|
853 |
+
# if npz exists, use them
|
854 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
+
|
856 |
+
self.register_image(image_info)
|
857 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
+
self.num_reg_images = 0
|
859 |
|
860 |
+
# TODO do not record tag freq when no tag
|
861 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
862 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
|
|
863 |
|
864 |
# check existence of all npz files
|
865 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
866 |
if use_npz_latents:
|
|
|
867 |
npz_any = False
|
868 |
npz_all = True
|
|
|
869 |
for image_info in self.image_data.values():
|
|
|
|
|
870 |
has_npz = image_info.latents_npz is not None
|
871 |
npz_any = npz_any or has_npz
|
872 |
|
873 |
+
if self.flip_aug:
|
874 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
|
|
875 |
npz_all = npz_all and has_npz
|
876 |
|
877 |
if npz_any and not npz_all:
|
|
|
883 |
elif not npz_all:
|
884 |
use_npz_latents = False
|
885 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
+
if self.flip_aug:
|
887 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
# else:
|
889 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
|
|
929 |
for image_info in self.image_data.values():
|
930 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
|
932 |
+
def image_key_to_npz_file(self, image_key):
|
933 |
base_name = os.path.splitext(image_key)[0]
|
934 |
npz_file_norm = base_name + '.npz'
|
935 |
|
|
|
941 |
return npz_file_norm, npz_file_flip
|
942 |
|
943 |
# image_key is relative path
|
944 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
945 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
946 |
|
947 |
if not os.path.exists(npz_file_norm):
|
948 |
npz_file_norm = None
|
|
|
953 |
return npz_file_norm, npz_file_flip
|
954 |
|
955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
|
960 |
train_dataset.set_current_epoch(1)
|
961 |
k = 0
|
962 |
+
for i, example in enumerate(train_dataset):
|
|
|
|
|
|
|
963 |
if example['latents'] is not None:
|
964 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
|
|
1364 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
|
1369 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
|
|
1387 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
parser.add_argument("--xformers", action="store_true",
|
|
|
1398 |
parser.add_argument("--vae", type=str, default=None,
|
1399 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
|
1401 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
|
|
1419 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
parser.add_argument("--lowram", action="store_true",
|
1429 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1431 |
if support_dreambooth:
|
1432 |
# DreamBooth training
|
1433 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
|
|
1449 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1453 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1454 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
|
|
1475 |
if support_caption_dropout:
|
1476 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1481 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
|
1485 |
if support_dreambooth:
|
|
|
1504 |
# region utils
|
1505 |
|
1506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1507 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
# backward compatibility
|
1509 |
if args.caption_extention is not None:
|
1510 |
args.caption_extension = args.caption_extention
|
1511 |
args.caption_extention = None
|
1512 |
|
1513 |
+
if args.cache_latents:
|
1514 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
+
|
1517 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
if args.resolution is not None:
|
1519 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
|
|
1536 |
|
1537 |
def load_tokenizer(args: argparse.Namespace):
|
1538 |
print("prepare tokenizer")
|
1539 |
+
if args.v2:
|
1540 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1541 |
+
else:
|
1542 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1543 |
+
if args.max_token_length is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1544 |
print(f"update token length: {args.max_token_length}")
|
|
|
|
|
|
|
|
|
|
|
1545 |
return tokenizer
|
1546 |
|
1547 |
|
|
|
1592 |
|
1593 |
|
1594 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
|
|
|
|
1596 |
if load_stable_diffusion_format:
|
1597 |
print("load StableDiffusion checkpoint")
|
1598 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1599 |
else:
|
1600 |
print("load Diffusers pretrained models")
|
1601 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
|
|
|
|
|
|
|
|
1602 |
text_encoder = pipe.text_encoder
|
1603 |
vae = pipe.vae
|
1604 |
unet = pipe.unet
|
|
|
1767 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1770 |
# endregion
|
1771 |
|
1772 |
# region 前処理用
|
lora_train_popup.py
ADDED
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from functools import partial
|
5 |
+
from typing import Union
|
6 |
+
import os
|
7 |
+
import tkinter as tk
|
8 |
+
from tkinter import filedialog as fd, ttk
|
9 |
+
from tkinter import simpledialog as sd
|
10 |
+
from tkinter import messagebox as mb
|
11 |
+
|
12 |
+
import torch.cuda
|
13 |
+
import train_network
|
14 |
+
import library.train_util as util
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
|
18 |
+
class ArgStore:
|
19 |
+
# Represents the entirety of all possible inputs for sd-scripts. they are ordered from most important to least
|
20 |
+
def __init__(self):
|
21 |
+
# Important, these are the most likely things you will modify
|
22 |
+
self.base_model: str = r"" # example path, r"E:\sd\stable-diffusion-webui\models\Stable-diffusion\nai.ckpt"
|
23 |
+
self.img_folder: str = r"" # is the folder path to your img folder, make sure to follow the guide here for folder setup: https://rentry.org/2chAI_LoRA_Dreambooth_guide_english#for-kohyas-script
|
24 |
+
self.output_folder: str = r"" # just the folder all epochs/safetensors are output
|
25 |
+
self.change_output_name: Union[str, None] = None # changes the output name of the epochs
|
26 |
+
self.save_json_folder: Union[str, None] = None # OPTIONAL, saves a json folder of your config to whatever location you set here.
|
27 |
+
self.load_json_path: Union[str, None] = None # OPTIONAL, loads a json file partially changes the config to match. things like folder paths do not get modified.
|
28 |
+
self.json_load_skip_list: Union[list[str], None] = ["save_json_folder", "reg_img_folder",
|
29 |
+
"lora_model_for_resume", "change_output_name",
|
30 |
+
"training_comment",
|
31 |
+
"json_load_skip_list"] # OPTIONAL, allows the user to define what they skip when loading a json, by default it loads everything, including all paths, set it up like this ["base_model", "img_folder", "output_folder"]
|
32 |
+
self.caption_dropout_rate: Union[float, None] = None # The rate at which captions for files get dropped.
|
33 |
+
self.caption_dropout_every_n_epochs: Union[int, None] = None # Defines how often an epoch will completely ignore
|
34 |
+
# captions, EX. 3 means it will ignore captions at epochs 3, 6, and 9
|
35 |
+
self.caption_tag_dropout_rate: Union[float, None] = None # Defines the rate at which a tag would be dropped, rather than the entire caption file
|
36 |
+
self.noise_offset: Union[float, None] = None # OPTIONAL, seems to help allow SD to gen better blacks and whites
|
37 |
+
# Kohya recommends, if you have it set, to use 0.1, not sure how
|
38 |
+
# high the value can be, I'm going to assume maximum of 1
|
39 |
+
|
40 |
+
self.net_dim: int = 128 # network dimension, 128 is the most common, however you might be able to get lesser to work
|
41 |
+
self.alpha: float = 128 # represents the scalar for training. the lower the alpha, the less gets learned per step. if you want the older way of training, set this to dim
|
42 |
+
# list of schedulers: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
|
43 |
+
self.scheduler: str = "cosine_with_restarts" # the scheduler for learning rate. Each does something specific
|
44 |
+
self.cosine_restarts: Union[int, None] = 1 # OPTIONAL, represents the number of times it restarts. Only matters if you are using cosine_with_restarts
|
45 |
+
self.scheduler_power: Union[float, None] = 1 # OPTIONAL, represents the power of the polynomial. Only matters if you are using polynomial
|
46 |
+
self.warmup_lr_ratio: Union[float, None] = None # OPTIONAL, Calculates the number of warmup steps based on the ratio given. Make sure to set this if you are using constant_with_warmup, None to ignore
|
47 |
+
self.learning_rate: Union[float, None] = 1e-4 # OPTIONAL, when not set, lr gets set to 1e-3 as per adamW. Personally, I suggest actually setting this as lower lr seems to be a small bit better.
|
48 |
+
self.text_encoder_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the text encoder, this overwrites the base lr I believe, None to ignore
|
49 |
+
self.unet_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the unet, this overwrites the base lr I believe, None to ignore
|
50 |
+
self.num_workers: int = 1 # The number of threads that are being used to load images, lower speeds up the start of epochs, but slows down the loading of data. The assumption here is that it increases the training time as you reduce this value
|
51 |
+
self.persistent_workers: bool = True # makes workers persistent, further reduces/eliminates the lag in between epochs. however it may increase memory usage
|
52 |
+
|
53 |
+
self.batch_size: int = 1 # The number of images that get processed at one time, this is directly proportional to your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 batch size
|
54 |
+
self.num_epochs: int = 1 # The number of epochs, if you set max steps this value is ignored as it doesn't calculate steps.
|
55 |
+
self.save_every_n_epochs: Union[int, None] = None # OPTIONAL, how often to save epochs, None to ignore
|
56 |
+
self.shuffle_captions: bool = False # OPTIONAL, False to ignore
|
57 |
+
self.keep_tokens: Union[int, None] = None # OPTIONAL, None to ignore
|
58 |
+
self.max_steps: Union[int, None] = None # OPTIONAL, if you have specific steps you want to hit, this allows you to set it directly. None to ignore
|
59 |
+
self.tag_occurrence_txt_file: bool = False # OPTIONAL, creates a txt file that has the entire occurrence of all tags in your dataset
|
60 |
+
# the metadata will also have this so long as you have metadata on, so no reason to have this on by default
|
61 |
+
# will automatically output to the same folder as your output checkpoints
|
62 |
+
self.sort_tag_occurrence_alphabetically: bool = False # OPTIONAL, only applies if tag_occurrence_txt_file is also true
|
63 |
+
# Will change the output to be alphabetically vs being occurrence based
|
64 |
+
|
65 |
+
# These are the second most likely things you will modify
|
66 |
+
self.train_resolution: int = 512
|
67 |
+
self.min_bucket_resolution: int = 320
|
68 |
+
self.max_bucket_resolution: int = 960
|
69 |
+
self.lora_model_for_resume: Union[str, None] = None # OPTIONAL, takes an input lora to continue training from, not exactly the way it *should* be, but it works, None to ignore
|
70 |
+
self.save_state: bool = False # OPTIONAL, is the intended way to save a training state to use for continuing training, False to ignore
|
71 |
+
self.load_previous_save_state: Union[str, None] = None # OPTIONAL, is the intended way to load a training state to use for continuing training, None to ignore
|
72 |
+
self.training_comment: Union[str, None] = None # OPTIONAL, great way to put in things like activation tokens right into the metadata. seems to not work at this point and time
|
73 |
+
self.unet_only: bool = False # OPTIONAL, set it to only train the unet
|
74 |
+
self.text_only: bool = False # OPTIONAL, set it to only train the text encoder
|
75 |
+
|
76 |
+
# These are the least likely things you will modify
|
77 |
+
self.reg_img_folder: Union[str, None] = None # OPTIONAL, None to ignore
|
78 |
+
self.clip_skip: int = 2 # If you are training on a model that is anime based, keep this at 2 as most models are designed for that
|
79 |
+
self.test_seed: int = 23 # this is the "reproducable seed", basically if you set the seed to this, you should be able to input a prompt from one of your training images and get a close representation of it
|
80 |
+
self.prior_loss_weight: float = 1 # is the loss weight much like Dreambooth, is required for LoRA training
|
81 |
+
self.gradient_checkpointing: bool = False # OPTIONAL, enables gradient checkpointing
|
82 |
+
self.gradient_acc_steps: Union[int, None] = None # OPTIONAL, not sure exactly what this means
|
83 |
+
self.mixed_precision: str = "fp16" # If you have the ability to use bf16, do it, it's better
|
84 |
+
self.save_precision: str = "fp16" # You can also save in bf16, but because it's not universally supported, I suggest you keep saving at fp16
|
85 |
+
self.save_as: str = "safetensors" # list is pt, ckpt, safetensors
|
86 |
+
self.caption_extension: str = ".txt" # the other option is .captions, but since wd1.4 tagger outputs as txt files, this is the default
|
87 |
+
self.max_clip_token_length = 150 # can be 75, 150, or 225 I believe, there is no reason to go higher than 150 though
|
88 |
+
self.buckets: bool = True
|
89 |
+
self.xformers: bool = True
|
90 |
+
self.use_8bit_adam: bool = True
|
91 |
+
self.cache_latents: bool = True
|
92 |
+
self.color_aug: bool = False # IMPORTANT: Clashes with cache_latents, only have one of the two on!
|
93 |
+
self.flip_aug: bool = False
|
94 |
+
self.vae: Union[str, None] = None # Seems to only make results worse when not using that specific vae, should probably not use
|
95 |
+
self.no_meta: bool = False # This removes the metadata that now gets saved into safetensors, (you should keep this on)
|
96 |
+
self.log_dir: Union[str, None] = None # output of logs, not useful to most people.
|
97 |
+
self.v2: bool = False # Sets up training for SD2.1
|
98 |
+
self.v_parameterization: bool = False # Only is used when v2 is also set and you are using the 768x version of v2
|
99 |
+
|
100 |
+
# Creates the dict that is used for the rest of the code, to facilitate easier json saving and loading
|
101 |
+
@staticmethod
|
102 |
+
def convert_args_to_dict():
|
103 |
+
return ArgStore().__dict__
|
104 |
+
|
105 |
+
|
106 |
+
def main():
|
107 |
+
parser = argparse.ArgumentParser()
|
108 |
+
setup_args(parser)
|
109 |
+
pre_args = parser.parse_args()
|
110 |
+
queues = 0
|
111 |
+
args_queue = []
|
112 |
+
cont = True
|
113 |
+
while cont:
|
114 |
+
arg_dict = ArgStore.convert_args_to_dict()
|
115 |
+
ret = mb.askyesno(message="Do you want to load a json config file?")
|
116 |
+
if ret:
|
117 |
+
load_json(ask_file("select json to load from", {"json"}), arg_dict)
|
118 |
+
arg_dict = ask_elements_trunc(arg_dict)
|
119 |
+
else:
|
120 |
+
arg_dict = ask_elements(arg_dict)
|
121 |
+
if pre_args.save_json_path or arg_dict["save_json_folder"]:
|
122 |
+
save_json(pre_args.save_json_path if pre_args.save_json_path else arg_dict['save_json_folder'], arg_dict)
|
123 |
+
args = create_arg_space(arg_dict)
|
124 |
+
args = parser.parse_args(args)
|
125 |
+
queues += 1
|
126 |
+
args_queue.append(args)
|
127 |
+
if arg_dict['tag_occurrence_txt_file']:
|
128 |
+
get_occurrence_of_tags(arg_dict)
|
129 |
+
ret = mb.askyesno(message="Do you want to queue another training?")
|
130 |
+
if not ret:
|
131 |
+
cont = False
|
132 |
+
for args in args_queue:
|
133 |
+
try:
|
134 |
+
train_network.train(args)
|
135 |
+
except Exception as e:
|
136 |
+
print(f"Failed to train this set of args.\nSkipping this training session.\nError is: {e}")
|
137 |
+
gc.collect()
|
138 |
+
torch.cuda.empty_cache()
|
139 |
+
|
140 |
+
|
141 |
+
def create_arg_space(args: dict) -> [str]:
|
142 |
+
# This is the list of args that are to be used regardless of setup
|
143 |
+
output = ["--network_module=networks.lora", f"--pretrained_model_name_or_path={args['base_model']}",
|
144 |
+
f"--train_data_dir={args['img_folder']}", f"--output_dir={args['output_folder']}",
|
145 |
+
f"--prior_loss_weight={args['prior_loss_weight']}", f"--caption_extension=" + args['caption_extension'],
|
146 |
+
f"--resolution={args['train_resolution']}", f"--train_batch_size={args['batch_size']}",
|
147 |
+
f"--mixed_precision={args['mixed_precision']}", f"--save_precision={args['save_precision']}",
|
148 |
+
f"--network_dim={args['net_dim']}", f"--save_model_as={args['save_as']}",
|
149 |
+
f"--clip_skip={args['clip_skip']}", f"--seed={args['test_seed']}",
|
150 |
+
f"--max_token_length={args['max_clip_token_length']}", f"--lr_scheduler={args['scheduler']}",
|
151 |
+
f"--network_alpha={args['alpha']}", f"--max_data_loader_n_workers={args['num_workers']}"]
|
152 |
+
if not args['max_steps']:
|
153 |
+
output.append(f"--max_train_epochs={args['num_epochs']}")
|
154 |
+
output += create_optional_args(args, find_max_steps(args))
|
155 |
+
else:
|
156 |
+
output.append(f"--max_train_steps={args['max_steps']}")
|
157 |
+
output += create_optional_args(args, args['max_steps'])
|
158 |
+
return output
|
159 |
+
|
160 |
+
|
161 |
+
def create_optional_args(args: dict, steps):
|
162 |
+
output = []
|
163 |
+
if args["reg_img_folder"]:
|
164 |
+
output.append(f"--reg_data_dir={args['reg_img_folder']}")
|
165 |
+
|
166 |
+
if args['lora_model_for_resume']:
|
167 |
+
output.append(f"--network_weights={args['lora_model_for_resume']}")
|
168 |
+
|
169 |
+
if args['save_every_n_epochs']:
|
170 |
+
output.append(f"--save_every_n_epochs={args['save_every_n_epochs']}")
|
171 |
+
else:
|
172 |
+
output.append("--save_every_n_epochs=999999")
|
173 |
+
|
174 |
+
if args['shuffle_captions']:
|
175 |
+
output.append("--shuffle_caption")
|
176 |
+
|
177 |
+
if args['keep_tokens'] and args['keep_tokens'] > 0:
|
178 |
+
output.append(f"--keep_tokens={args['keep_tokens']}")
|
179 |
+
|
180 |
+
if args['buckets']:
|
181 |
+
output.append("--enable_bucket")
|
182 |
+
output.append(f"--min_bucket_reso={args['min_bucket_resolution']}")
|
183 |
+
output.append(f"--max_bucket_reso={args['max_bucket_resolution']}")
|
184 |
+
|
185 |
+
if args['use_8bit_adam']:
|
186 |
+
output.append("--use_8bit_adam")
|
187 |
+
|
188 |
+
if args['xformers']:
|
189 |
+
output.append("--xformers")
|
190 |
+
|
191 |
+
if args['color_aug']:
|
192 |
+
if args['cache_latents']:
|
193 |
+
print("color_aug and cache_latents conflict with one another. Please select only one")
|
194 |
+
quit(1)
|
195 |
+
output.append("--color_aug")
|
196 |
+
|
197 |
+
if args['flip_aug']:
|
198 |
+
output.append("--flip_aug")
|
199 |
+
|
200 |
+
if args['cache_latents']:
|
201 |
+
output.append("--cache_latents")
|
202 |
+
|
203 |
+
if args['warmup_lr_ratio'] and args['warmup_lr_ratio'] > 0:
|
204 |
+
warmup_steps = int(steps * args['warmup_lr_ratio'])
|
205 |
+
output.append(f"--lr_warmup_steps={warmup_steps}")
|
206 |
+
|
207 |
+
if args['gradient_checkpointing']:
|
208 |
+
output.append("--gradient_checkpointing")
|
209 |
+
|
210 |
+
if args['gradient_acc_steps'] and args['gradient_acc_steps'] > 0 and args['gradient_checkpointing']:
|
211 |
+
output.append(f"--gradient_accumulation_steps={args['gradient_acc_steps']}")
|
212 |
+
|
213 |
+
if args['learning_rate'] and args['learning_rate'] > 0:
|
214 |
+
output.append(f"--learning_rate={args['learning_rate']}")
|
215 |
+
|
216 |
+
if args['text_encoder_lr'] and args['text_encoder_lr'] > 0:
|
217 |
+
output.append(f"--text_encoder_lr={args['text_encoder_lr']}")
|
218 |
+
|
219 |
+
if args['unet_lr'] and args['unet_lr'] > 0:
|
220 |
+
output.append(f"--unet_lr={args['unet_lr']}")
|
221 |
+
|
222 |
+
if args['vae']:
|
223 |
+
output.append(f"--vae={args['vae']}")
|
224 |
+
|
225 |
+
if args['no_meta']:
|
226 |
+
output.append("--no_metadata")
|
227 |
+
|
228 |
+
if args['save_state']:
|
229 |
+
output.append("--save_state")
|
230 |
+
|
231 |
+
if args['load_previous_save_state']:
|
232 |
+
output.append(f"--resume={args['load_previous_save_state']}")
|
233 |
+
|
234 |
+
if args['change_output_name']:
|
235 |
+
output.append(f"--output_name={args['change_output_name']}")
|
236 |
+
|
237 |
+
if args['training_comment']:
|
238 |
+
output.append(f"--training_comment={args['training_comment']}")
|
239 |
+
|
240 |
+
if args['cosine_restarts'] and args['scheduler'] == "cosine_with_restarts":
|
241 |
+
output.append(f"--lr_scheduler_num_cycles={args['cosine_restarts']}")
|
242 |
+
|
243 |
+
if args['scheduler_power'] and args['scheduler'] == "polynomial":
|
244 |
+
output.append(f"--lr_scheduler_power={args['scheduler_power']}")
|
245 |
+
|
246 |
+
if args['persistent_workers']:
|
247 |
+
output.append(f"--persistent_data_loader_workers")
|
248 |
+
|
249 |
+
if args['unet_only']:
|
250 |
+
output.append("--network_train_unet_only")
|
251 |
+
|
252 |
+
if args['text_only'] and not args['unet_only']:
|
253 |
+
output.append("--network_train_text_encoder_only")
|
254 |
+
|
255 |
+
if args["log_dir"]:
|
256 |
+
output.append(f"--logging_dir={args['log_dir']}")
|
257 |
+
|
258 |
+
if args['caption_dropout_rate']:
|
259 |
+
output.append(f"--caption_dropout_rate={args['caption_dropout_rate']}")
|
260 |
+
|
261 |
+
if args['caption_dropout_every_n_epochs']:
|
262 |
+
output.append(f"--caption_dropout_every_n_epochs={args['caption_dropout_every_n_epochs']}")
|
263 |
+
|
264 |
+
if args['caption_tag_dropout_rate']:
|
265 |
+
output.append(f"--caption_tag_dropout_rate={args['caption_tag_dropout_rate']}")
|
266 |
+
|
267 |
+
if args['v2']:
|
268 |
+
output.append("--v2")
|
269 |
+
|
270 |
+
if args['v2'] and args['v_parameterization']:
|
271 |
+
output.append("--v_parameterization")
|
272 |
+
|
273 |
+
if args['noise_offset']:
|
274 |
+
output.append(f"--noise_offset={args['noise_offset']}")
|
275 |
+
return output
|
276 |
+
|
277 |
+
|
278 |
+
def find_max_steps(args: dict) -> int:
|
279 |
+
total_steps = 0
|
280 |
+
folders = os.listdir(args["img_folder"])
|
281 |
+
for folder in folders:
|
282 |
+
if not os.path.isdir(os.path.join(args["img_folder"], folder)):
|
283 |
+
continue
|
284 |
+
num_repeats = folder.split("_")
|
285 |
+
if len(num_repeats) < 2:
|
286 |
+
print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
|
287 |
+
continue
|
288 |
+
try:
|
289 |
+
num_repeats = int(num_repeats[0])
|
290 |
+
except ValueError:
|
291 |
+
print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
|
292 |
+
continue
|
293 |
+
imgs = 0
|
294 |
+
for file in os.listdir(os.path.join(args["img_folder"], folder)):
|
295 |
+
if os.path.isdir(file):
|
296 |
+
continue
|
297 |
+
ext = file.split(".")
|
298 |
+
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
|
299 |
+
imgs += 1
|
300 |
+
total_steps += (num_repeats * imgs)
|
301 |
+
total_steps = int((total_steps / args["batch_size"]) * args["num_epochs"])
|
302 |
+
return total_steps
|
303 |
+
|
304 |
+
|
305 |
+
def add_misc_args(parser):
|
306 |
+
parser.add_argument("--save_json_path", type=str, default=None,
|
307 |
+
help="Path to save a configuration json file to")
|
308 |
+
parser.add_argument("--load_json_path", type=str, default=None,
|
309 |
+
help="Path to a json file to configure things from")
|
310 |
+
parser.add_argument("--no_metadata", action='store_true',
|
311 |
+
help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
312 |
+
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
313 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
|
314 |
+
|
315 |
+
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
316 |
+
parser.add_argument("--text_encoder_lr", type=float, default=None,
|
317 |
+
help="learning rate for Text Encoder / Text Encoderの学習率")
|
318 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
319 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
320 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
321 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
322 |
+
|
323 |
+
parser.add_argument("--network_weights", type=str, default=None,
|
324 |
+
help="pretrained weights for network / 学習するネットワークの初期重み")
|
325 |
+
parser.add_argument("--network_module", type=str, default=None,
|
326 |
+
help='network module to train / 学習対象のネットワークのモジュール')
|
327 |
+
parser.add_argument("--network_dim", type=int, default=None,
|
328 |
+
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
329 |
+
parser.add_argument("--network_alpha", type=float, default=1,
|
330 |
+
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
|
331 |
+
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
332 |
+
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
333 |
+
parser.add_argument("--network_train_unet_only", action="store_true",
|
334 |
+
help="only training U-Net part / U-Net関連部分のみ学習する")
|
335 |
+
parser.add_argument("--network_train_text_encoder_only", action="store_true",
|
336 |
+
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
337 |
+
parser.add_argument("--training_comment", type=str, default=None,
|
338 |
+
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
339 |
+
|
340 |
+
|
341 |
+
def setup_args(parser):
|
342 |
+
util.add_sd_models_arguments(parser)
|
343 |
+
util.add_dataset_arguments(parser, True, True, True)
|
344 |
+
util.add_training_arguments(parser, True)
|
345 |
+
add_misc_args(parser)
|
346 |
+
|
347 |
+
|
348 |
+
def get_occurrence_of_tags(args):
|
349 |
+
extension = args['caption_extension']
|
350 |
+
img_folder = args['img_folder']
|
351 |
+
output_folder = args['output_folder']
|
352 |
+
occurrence_dict = {}
|
353 |
+
print(img_folder)
|
354 |
+
for folder in os.listdir(img_folder):
|
355 |
+
print(folder)
|
356 |
+
if not os.path.isdir(os.path.join(img_folder, folder)):
|
357 |
+
continue
|
358 |
+
for file in os.listdir(os.path.join(img_folder, folder)):
|
359 |
+
if not os.path.isfile(os.path.join(img_folder, folder, file)):
|
360 |
+
continue
|
361 |
+
ext = os.path.splitext(file)[1]
|
362 |
+
if ext != extension:
|
363 |
+
continue
|
364 |
+
get_tags_from_file(os.path.join(img_folder, folder, file), occurrence_dict)
|
365 |
+
if not args['sort_tag_occurrence_alphabetically']:
|
366 |
+
output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[1], reverse=True)}
|
367 |
+
else:
|
368 |
+
output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[0])}
|
369 |
+
name = args['change_output_name'] if args['change_output_name'] else "last"
|
370 |
+
with open(os.path.join(output_folder, f"{name}.txt"), "w") as f:
|
371 |
+
f.write(f"Below is a list of keywords used during the training of {args['change_output_name']}:\n")
|
372 |
+
for k, v in output_list.items():
|
373 |
+
f.write(f"[{v}] {k}\n")
|
374 |
+
print(f"Created a txt file named {name}.txt in the output folder")
|
375 |
+
|
376 |
+
|
377 |
+
def get_tags_from_file(file, occurrence_dict):
|
378 |
+
f = open(file)
|
379 |
+
temp = f.read().replace(", ", ",").split(",")
|
380 |
+
f.close()
|
381 |
+
for tag in temp:
|
382 |
+
if tag in occurrence_dict:
|
383 |
+
occurrence_dict[tag] += 1
|
384 |
+
else:
|
385 |
+
occurrence_dict[tag] = 1
|
386 |
+
|
387 |
+
|
388 |
+
def ask_file(message, accepted_ext_list, file_path=None):
|
389 |
+
mb.showinfo(message=message)
|
390 |
+
res = ""
|
391 |
+
_initialdir = ""
|
392 |
+
_initialfile = ""
|
393 |
+
if file_path != None:
|
394 |
+
_initialdir = os.path.dirname(file_path) if os.path.exists(file_path) else ""
|
395 |
+
_initialfile = os.path.basename(file_path) if os.path.exists(file_path) else ""
|
396 |
+
|
397 |
+
while res == "":
|
398 |
+
res = fd.askopenfilename(title=message, initialdir=_initialdir, initialfile=_initialfile)
|
399 |
+
if res == "" or type(res) == tuple:
|
400 |
+
ret = mb.askretrycancel(message="Do you want to to cancel training?")
|
401 |
+
if not ret:
|
402 |
+
exit()
|
403 |
+
continue
|
404 |
+
elif not os.path.exists(res):
|
405 |
+
res = ""
|
406 |
+
continue
|
407 |
+
_, name = os.path.split(res)
|
408 |
+
split_name = name.split(".")
|
409 |
+
if split_name[-1] not in accepted_ext_list:
|
410 |
+
res = ""
|
411 |
+
return res
|
412 |
+
|
413 |
+
|
414 |
+
def ask_dir(message, dir_path=None):
|
415 |
+
mb.showinfo(message=message)
|
416 |
+
res = ""
|
417 |
+
_initialdir = ""
|
418 |
+
if dir_path != None:
|
419 |
+
_initialdir = dir_path if os.path.exists(dir_path) else ""
|
420 |
+
while res == "":
|
421 |
+
res = fd.askdirectory(title=message, initialdir=_initialdir)
|
422 |
+
if res == "" or type(res) == tuple:
|
423 |
+
ret = mb.askretrycancel(message="Do you want to to cancel training?")
|
424 |
+
if not ret:
|
425 |
+
exit()
|
426 |
+
continue
|
427 |
+
if not os.path.exists(res):
|
428 |
+
res = ""
|
429 |
+
return res
|
430 |
+
|
431 |
+
|
432 |
+
def ask_elements_trunc(args: dict):
|
433 |
+
args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
|
434 |
+
args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
|
435 |
+
args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
|
436 |
+
|
437 |
+
ret = mb.askyesno(message="Do you want to save a json of your configuration?")
|
438 |
+
if ret:
|
439 |
+
args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
|
440 |
+
else:
|
441 |
+
args['save_json_folder'] = None
|
442 |
+
|
443 |
+
ret = mb.askyesno(message="Are you training on a SD2 based model?")
|
444 |
+
if ret:
|
445 |
+
args['v2'] = True
|
446 |
+
|
447 |
+
ret = mb.askyesno(message="Are you training on an realistic model?")
|
448 |
+
if ret:
|
449 |
+
args['clip_skip'] = 1
|
450 |
+
|
451 |
+
if args['v2']:
|
452 |
+
ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
|
453 |
+
if ret:
|
454 |
+
args['v_parameterization'] = True
|
455 |
+
|
456 |
+
ret = mb.askyesno(message="Do you want to use regularization images?")
|
457 |
+
if ret:
|
458 |
+
args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
|
459 |
+
else:
|
460 |
+
args['reg_img_folder'] = None
|
461 |
+
|
462 |
+
ret = mb.askyesno(message="Do you want to continue from an earlier version?")
|
463 |
+
if ret:
|
464 |
+
args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
|
465 |
+
args['lora_model_for_resume'])
|
466 |
+
else:
|
467 |
+
args['lora_model_for_resume'] = None
|
468 |
+
|
469 |
+
ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
|
470 |
+
"within your dataset but it can also ruin learning an asymmetrical element\n")
|
471 |
+
if ret:
|
472 |
+
args['flip_aug'] = True
|
473 |
+
|
474 |
+
ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
|
475 |
+
if ret:
|
476 |
+
ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
|
477 |
+
"Cancel keeps outputs the original")
|
478 |
+
if ret:
|
479 |
+
args['change_output_name'] = ret
|
480 |
+
else:
|
481 |
+
args['change_output_name'] = None
|
482 |
+
|
483 |
+
ret = sd.askstring(title="comment",
|
484 |
+
prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
|
485 |
+
"be to include how to use, such as activation keywords.\nCancel will leave empty")
|
486 |
+
if ret is None:
|
487 |
+
args['training_comment'] = ret
|
488 |
+
else:
|
489 |
+
args['training_comment'] = None
|
490 |
+
|
491 |
+
ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
|
492 |
+
if ret:
|
493 |
+
button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
|
494 |
+
button.window.mainloop()
|
495 |
+
if button.current_value != "":
|
496 |
+
args[button.current_value] = True
|
497 |
+
|
498 |
+
ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
|
499 |
+
"of all tags that you have used in your training data?\n")
|
500 |
+
if ret:
|
501 |
+
args['tag_occurrence_txt_file'] = True
|
502 |
+
button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
|
503 |
+
button.window.mainloop()
|
504 |
+
if button.current_value == "alphabetically":
|
505 |
+
args['sort_tag_occurrence_alphabetically'] = True
|
506 |
+
|
507 |
+
ret = mb.askyesno(message="Do you want to use caption dropout?")
|
508 |
+
if ret:
|
509 |
+
ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
|
510 |
+
if ret:
|
511 |
+
ret = sd.askinteger(title="Caption_File_Dropout",
|
512 |
+
prompt="How often do you want caption files to drop out?\n"
|
513 |
+
"enter a number from 0 to 100 that is the percentage chance of dropout\n"
|
514 |
+
"Cancel sets to 0")
|
515 |
+
if ret and 0 <= ret <= 100:
|
516 |
+
args['caption_dropout_rate'] = ret / 100.0
|
517 |
+
|
518 |
+
ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
|
519 |
+
if ret:
|
520 |
+
ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
|
521 |
+
"epoch with no captions\nSo if you set 3, then every"
|
522 |
+
"three epochs will not have captions (3, 6, 9)\n"
|
523 |
+
"Cancel will set to None")
|
524 |
+
if ret:
|
525 |
+
args['caption_dropout_every_n_epochs'] = ret
|
526 |
+
|
527 |
+
ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
|
528 |
+
if ret:
|
529 |
+
ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
|
530 |
+
"Enter a number between 0 and 100, that is the percentage"
|
531 |
+
"chance of dropout.\nCancel sets to 0")
|
532 |
+
if ret and 0 <= ret <= 100:
|
533 |
+
args['caption_tag_dropout_rate'] = ret / 100.0
|
534 |
+
|
535 |
+
ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
|
536 |
+
"darker or lighter images using this than normal.")
|
537 |
+
if ret:
|
538 |
+
ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
|
539 |
+
"but it can go higher. Cancel defaults to 0.1")
|
540 |
+
if ret:
|
541 |
+
args['noise_offset'] = ret
|
542 |
+
else:
|
543 |
+
args['noise_offset'] = 0.1
|
544 |
+
return args
|
545 |
+
|
546 |
+
|
547 |
+
def ask_elements(args: dict):
|
548 |
+
# start with file dialog
|
549 |
+
args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
|
550 |
+
args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
|
551 |
+
args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
|
552 |
+
|
553 |
+
# optional file dialog
|
554 |
+
ret = mb.askyesno(message="Do you want to save a json of your configuration?")
|
555 |
+
if ret:
|
556 |
+
args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
|
557 |
+
else:
|
558 |
+
args['save_json_folder'] = None
|
559 |
+
|
560 |
+
ret = mb.askyesno(message="Are you training on a SD2 based model?")
|
561 |
+
if ret:
|
562 |
+
args['v2'] = True
|
563 |
+
|
564 |
+
ret = mb.askyesno(message="Are you training on an realistic model?")
|
565 |
+
if ret:
|
566 |
+
args['clip_skip'] = 1
|
567 |
+
|
568 |
+
if args['v2']:
|
569 |
+
ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
|
570 |
+
if ret:
|
571 |
+
args['v_parameterization'] = True
|
572 |
+
|
573 |
+
ret = mb.askyesno(message="Do you want to use regularization images?")
|
574 |
+
if ret:
|
575 |
+
args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
|
576 |
+
else:
|
577 |
+
args['reg_img_folder'] = None
|
578 |
+
|
579 |
+
ret = mb.askyesno(message="Do you want to continue from an earlier version?")
|
580 |
+
if ret:
|
581 |
+
args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
|
582 |
+
args['lora_model_for_resume'])
|
583 |
+
else:
|
584 |
+
args['lora_model_for_resume'] = None
|
585 |
+
|
586 |
+
ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
|
587 |
+
"within your dataset but it can also ruin learning an asymmetrical element\n")
|
588 |
+
if ret:
|
589 |
+
args['flip_aug'] = True
|
590 |
+
|
591 |
+
# text based required elements
|
592 |
+
ret = sd.askinteger(title="batch_size",
|
593 |
+
prompt="The number of images that get processed at one time, this is directly proportional to "
|
594 |
+
"your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 "
|
595 |
+
"batch size\nHow large is your batch size going to be?\nCancel will default to 1")
|
596 |
+
if ret is None:
|
597 |
+
args['batch_size'] = 1
|
598 |
+
else:
|
599 |
+
args['batch_size'] = ret
|
600 |
+
|
601 |
+
ret = sd.askinteger(title="num_epochs", prompt="How many epochs do you want?\nCancel will default to 1")
|
602 |
+
if ret is None:
|
603 |
+
args['num_epochs'] = 1
|
604 |
+
else:
|
605 |
+
args['num_epochs'] = ret
|
606 |
+
|
607 |
+
ret = sd.askinteger(title="network_dim", prompt="What is the dim size you want to use?\nCancel will default to 128")
|
608 |
+
if ret is None:
|
609 |
+
args['net_dim'] = 128
|
610 |
+
else:
|
611 |
+
args['net_dim'] = ret
|
612 |
+
|
613 |
+
ret = sd.askfloat(title="alpha", prompt="Alpha is the scalar of the training, generally a good starting point is "
|
614 |
+
"0.5x dim size\nWhat Alpha do you want?\nCancel will default to equal to "
|
615 |
+
"0.5 x network_dim")
|
616 |
+
if ret is None:
|
617 |
+
args['alpha'] = args['net_dim'] / 2
|
618 |
+
else:
|
619 |
+
args['alpha'] = ret
|
620 |
+
|
621 |
+
ret = sd.askinteger(title="resolution", prompt="How large of a resolution do you want to train at?\n"
|
622 |
+
"Cancel will default to 512")
|
623 |
+
if ret is None:
|
624 |
+
args['train_resolution'] = 512
|
625 |
+
else:
|
626 |
+
args['train_resolution'] = ret
|
627 |
+
|
628 |
+
ret = sd.askfloat(title="learning_rate", prompt="What learning rate do you want to use?\n"
|
629 |
+
"Cancel will default to 1e-4")
|
630 |
+
if ret is None:
|
631 |
+
args['learning_rate'] = 1e-4
|
632 |
+
else:
|
633 |
+
args['learning_rate'] = ret
|
634 |
+
|
635 |
+
ret = sd.askfloat(title="text_encoder_lr", prompt="Do you want to set the text_encoder_lr?\n"
|
636 |
+
"Cancel will default to None")
|
637 |
+
if ret is None:
|
638 |
+
args['text_encoder_lr'] = None
|
639 |
+
else:
|
640 |
+
args['text_encoder_lr'] = ret
|
641 |
+
|
642 |
+
ret = sd.askfloat(title="unet_lr", prompt="Do you want to set the unet_lr?\nCancel will default to None")
|
643 |
+
if ret is None:
|
644 |
+
args['unet_lr'] = None
|
645 |
+
else:
|
646 |
+
args['unet_lr'] = ret
|
647 |
+
|
648 |
+
button = ButtonBox("Which scheduler do you want?", ["cosine_with_restarts", "cosine", "polynomial",
|
649 |
+
"constant", "constant_with_warmup", "linear"])
|
650 |
+
button.window.mainloop()
|
651 |
+
args['scheduler'] = button.current_value if button.current_value != "" else "cosine_with_restarts"
|
652 |
+
|
653 |
+
if args['scheduler'] == "cosine_with_restarts":
|
654 |
+
ret = sd.askinteger(title="Cycle Count",
|
655 |
+
prompt="How many times do you want cosine to restart?\nThis is the entire amount of times "
|
656 |
+
"it will restart for the entire training\nCancel will default to 1")
|
657 |
+
if ret is None:
|
658 |
+
args['cosine_restarts'] = 1
|
659 |
+
else:
|
660 |
+
args['cosine_restarts'] = ret
|
661 |
+
|
662 |
+
if args['scheduler'] == "polynomial":
|
663 |
+
ret = sd.askfloat(title="Poly Strength",
|
664 |
+
prompt="What power do you want to set your polynomial to?\nhigher power means that the "
|
665 |
+
"model reduces the learning more more aggressively from initial training.\n1 = "
|
666 |
+
"linear\nCancel sets to 1")
|
667 |
+
if ret is None:
|
668 |
+
args['scheduler_power'] = 1
|
669 |
+
else:
|
670 |
+
args['scheduler_power'] = ret
|
671 |
+
|
672 |
+
ret = mb.askyesno(message="Do you want to save epochs as it trains?")
|
673 |
+
if ret:
|
674 |
+
ret = sd.askinteger(title="save_epoch",
|
675 |
+
prompt="How often do you want to save epochs?\nCancel will default to 1")
|
676 |
+
if ret is None:
|
677 |
+
args['save_every_n_epochs'] = 1
|
678 |
+
else:
|
679 |
+
args['save_every_n_epochs'] = ret
|
680 |
+
|
681 |
+
ret = mb.askyesno(message="Do you want to shuffle captions?")
|
682 |
+
if ret:
|
683 |
+
args['shuffle_captions'] = True
|
684 |
+
else:
|
685 |
+
args['shuffle_captions'] = False
|
686 |
+
|
687 |
+
ret = mb.askyesno(message="Do you want to keep some tokens at the front of your captions?")
|
688 |
+
if ret:
|
689 |
+
ret = sd.askinteger(title="keep_tokens", prompt="How many do you want to keep at the front?"
|
690 |
+
"\nCancel will default to 1")
|
691 |
+
if ret is None:
|
692 |
+
args['keep_tokens'] = 1
|
693 |
+
else:
|
694 |
+
args['keep_tokens'] = ret
|
695 |
+
|
696 |
+
ret = mb.askyesno(message="Do you want to have a warmup ratio?")
|
697 |
+
if ret:
|
698 |
+
ret = sd.askfloat(title="warmup_ratio", prompt="What is the ratio of steps to use as warmup "
|
699 |
+
"steps?\nCancel will default to None")
|
700 |
+
if ret is None:
|
701 |
+
args['warmup_lr_ratio'] = None
|
702 |
+
else:
|
703 |
+
args['warmup_lr_ratio'] = ret
|
704 |
+
|
705 |
+
ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
|
706 |
+
if ret:
|
707 |
+
ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
|
708 |
+
"Cancel keeps outputs the original")
|
709 |
+
if ret:
|
710 |
+
args['change_output_name'] = ret
|
711 |
+
else:
|
712 |
+
args['change_output_name'] = None
|
713 |
+
|
714 |
+
ret = sd.askstring(title="comment",
|
715 |
+
prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
|
716 |
+
"be to include how to use, such as activation keywords.\nCancel will leave empty")
|
717 |
+
if ret is None:
|
718 |
+
args['training_comment'] = ret
|
719 |
+
else:
|
720 |
+
args['training_comment'] = None
|
721 |
+
|
722 |
+
ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
|
723 |
+
if ret:
|
724 |
+
if ret:
|
725 |
+
button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
|
726 |
+
button.window.mainloop()
|
727 |
+
if button.current_value != "":
|
728 |
+
args[button.current_value] = True
|
729 |
+
|
730 |
+
ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
|
731 |
+
"of all tags that you have used in your training data?\n")
|
732 |
+
if ret:
|
733 |
+
args['tag_occurrence_txt_file'] = True
|
734 |
+
button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
|
735 |
+
button.window.mainloop()
|
736 |
+
if button.current_value == "alphabetically":
|
737 |
+
args['sort_tag_occurrence_alphabetically'] = True
|
738 |
+
|
739 |
+
ret = mb.askyesno(message="Do you want to use caption dropout?")
|
740 |
+
if ret:
|
741 |
+
ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
|
742 |
+
if ret:
|
743 |
+
ret = sd.askinteger(title="Caption_File_Dropout",
|
744 |
+
prompt="How often do you want caption files to drop out?\n"
|
745 |
+
"enter a number from 0 to 100 that is the percentage chance of dropout\n"
|
746 |
+
"Cancel sets to 0")
|
747 |
+
if ret and 0 <= ret <= 100:
|
748 |
+
args['caption_dropout_rate'] = ret / 100.0
|
749 |
+
|
750 |
+
ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
|
751 |
+
if ret:
|
752 |
+
ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
|
753 |
+
"epoch with no captions\nSo if you set 3, then every"
|
754 |
+
"three epochs will not have captions (3, 6, 9)\n"
|
755 |
+
"Cancel will set to None")
|
756 |
+
if ret:
|
757 |
+
args['caption_dropout_every_n_epochs'] = ret
|
758 |
+
|
759 |
+
ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
|
760 |
+
if ret:
|
761 |
+
ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
|
762 |
+
"Enter a number between 0 and 100, that is the percentage"
|
763 |
+
"chance of dropout.\nCancel sets to 0")
|
764 |
+
if ret and 0 <= ret <= 100:
|
765 |
+
args['caption_tag_dropout_rate'] = ret / 100.0
|
766 |
+
|
767 |
+
ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
|
768 |
+
"darker or lighter images using this than normal.")
|
769 |
+
if ret:
|
770 |
+
ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
|
771 |
+
"but it can go higher. Cancel defaults to 0.1")
|
772 |
+
if ret:
|
773 |
+
args['noise_offset'] = ret
|
774 |
+
else:
|
775 |
+
args['noise_offset'] = 0.1
|
776 |
+
return args
|
777 |
+
|
778 |
+
|
779 |
+
def save_json(path, obj: dict) -> None:
|
780 |
+
fp = open(os.path.join(path, f"config-{time.time()}.json"), "w")
|
781 |
+
json.dump(obj, fp=fp, indent=4)
|
782 |
+
fp.close()
|
783 |
+
|
784 |
+
|
785 |
+
def load_json(path, obj: dict) -> dict:
|
786 |
+
with open(path) as f:
|
787 |
+
json_obj = json.loads(f.read())
|
788 |
+
print("loaded json, setting variables...")
|
789 |
+
ui_name_scheme = {"pretrained_model_name_or_path": "base_model", "logging_dir": "log_dir",
|
790 |
+
"train_data_dir": "img_folder", "reg_data_dir": "reg_img_folder",
|
791 |
+
"output_dir": "output_folder", "max_resolution": "train_resolution",
|
792 |
+
"lr_scheduler": "scheduler", "lr_warmup": "warmup_lr_ratio",
|
793 |
+
"train_batch_size": "batch_size", "epoch": "num_epochs",
|
794 |
+
"save_at_n_epochs": "save_every_n_epochs", "num_cpu_threads_per_process": "num_workers",
|
795 |
+
"enable_bucket": "buckets", "save_model_as": "save_as", "shuffle_caption": "shuffle_captions",
|
796 |
+
"resume": "load_previous_save_state", "network_dim": "net_dim",
|
797 |
+
"gradient_accumulation_steps": "gradient_acc_steps", "output_name": "change_output_name",
|
798 |
+
"network_alpha": "alpha", "lr_scheduler_num_cycles": "cosine_restarts",
|
799 |
+
"lr_scheduler_power": "scheduler_power"}
|
800 |
+
|
801 |
+
for key in list(json_obj):
|
802 |
+
if key in ui_name_scheme:
|
803 |
+
json_obj[ui_name_scheme[key]] = json_obj[key]
|
804 |
+
if ui_name_scheme[key] in {"batch_size", "num_epochs"}:
|
805 |
+
try:
|
806 |
+
json_obj[ui_name_scheme[key]] = int(json_obj[ui_name_scheme[key]])
|
807 |
+
except ValueError:
|
808 |
+
print(f"attempting to load {key} from json failed as input isn't an integer")
|
809 |
+
quit(1)
|
810 |
+
|
811 |
+
for key in list(json_obj):
|
812 |
+
if obj["json_load_skip_list"] and key in obj["json_load_skip_list"]:
|
813 |
+
continue
|
814 |
+
if key in obj:
|
815 |
+
if key in {"keep_tokens", "warmup_lr_ratio"}:
|
816 |
+
json_obj[key] = int(json_obj[key]) if json_obj[key] is not None else None
|
817 |
+
if key in {"learning_rate", "unet_lr", "text_encoder_lr"}:
|
818 |
+
json_obj[key] = float(json_obj[key]) if json_obj[key] is not None else None
|
819 |
+
if obj[key] != json_obj[key]:
|
820 |
+
print_change(key, obj[key], json_obj[key])
|
821 |
+
obj[key] = json_obj[key]
|
822 |
+
print("completed changing variables.")
|
823 |
+
return obj
|
824 |
+
|
825 |
+
|
826 |
+
def print_change(value, old, new):
|
827 |
+
print(f"{value} changed from {old} to {new}")
|
828 |
+
|
829 |
+
|
830 |
+
class ButtonBox:
|
831 |
+
def __init__(self, label: str, button_name_list: list[str]) -> None:
|
832 |
+
self.window = tk.Tk()
|
833 |
+
self.button_list = []
|
834 |
+
self.current_value = ""
|
835 |
+
|
836 |
+
self.window.attributes("-topmost", True)
|
837 |
+
self.window.resizable(False, False)
|
838 |
+
self.window.eval('tk::PlaceWindow . center')
|
839 |
+
|
840 |
+
def del_window():
|
841 |
+
self.window.quit()
|
842 |
+
self.window.destroy()
|
843 |
+
|
844 |
+
self.window.protocol("WM_DELETE_WINDOW", del_window)
|
845 |
+
tk.Label(text=label, master=self.window).pack()
|
846 |
+
for button in button_name_list:
|
847 |
+
self.button_list.append(ttk.Button(text=button, master=self.window,
|
848 |
+
command=partial(self.set_current_value, button)))
|
849 |
+
self.button_list[-1].pack()
|
850 |
+
|
851 |
+
def set_current_value(self, value):
|
852 |
+
self.current_value = value
|
853 |
+
self.window.quit()
|
854 |
+
self.window.destroy()
|
855 |
+
|
856 |
+
|
857 |
+
root = tk.Tk()
|
858 |
+
root.attributes('-topmost', True)
|
859 |
+
root.withdraw()
|
860 |
+
|
861 |
+
if __name__ == "__main__":
|
862 |
+
main()
|
lycoris/kohya.py
CHANGED
@@ -5,7 +5,6 @@
|
|
5 |
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
6 |
|
7 |
import math
|
8 |
-
from warnings import warn
|
9 |
import os
|
10 |
from typing import List
|
11 |
import torch
|
@@ -28,22 +27,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
|
28 |
}[algo]
|
29 |
|
30 |
print(f'Using rank adaptation algo: {algo}')
|
31 |
-
|
32 |
-
if (algo == 'loha'
|
33 |
-
and not kwargs.get('no_dim_warn', False)
|
34 |
-
and (network_dim>64 or conv_dim>64)):
|
35 |
-
print('='*20 + 'WARNING' + '='*20)
|
36 |
-
warn(
|
37 |
-
(
|
38 |
-
"You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)"
|
39 |
-
"in Hadamard Product representation!\n"
|
40 |
-
"Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n"
|
41 |
-
"If you just want to use high dim loha, please consider use lower lr."
|
42 |
-
),
|
43 |
-
stacklevel=2,
|
44 |
-
)
|
45 |
-
print('='*20 + 'WARNING' + '='*20)
|
46 |
-
|
47 |
network = LoRANetwork(
|
48 |
text_encoder, unet,
|
49 |
multiplier=multiplier,
|
|
|
5 |
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
6 |
|
7 |
import math
|
|
|
8 |
import os
|
9 |
from typing import List
|
10 |
import torch
|
|
|
27 |
}[algo]
|
28 |
|
29 |
print(f'Using rank adaptation algo: {algo}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
network = LoRANetwork(
|
31 |
text_encoder, unet,
|
32 |
multiplier=multiplier,
|
lycoris/loha.py
CHANGED
@@ -36,12 +36,7 @@ class LohaModule(nn.Module):
|
|
36 |
Hadamard product Implementaion for Low Rank Adaptation
|
37 |
"""
|
38 |
|
39 |
-
def __init__(
|
40 |
-
self,
|
41 |
-
lora_name,
|
42 |
-
org_module: nn.Module,
|
43 |
-
multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
|
44 |
-
):
|
45 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
46 |
super().__init__()
|
47 |
self.lora_name = lora_name
|
|
|
36 |
Hadamard product Implementaion for Low Rank Adaptation
|
37 |
"""
|
38 |
|
39 |
+
def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
|
|
|
|
|
|
|
|
|
|
|
40 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
41 |
super().__init__()
|
42 |
self.lora_name = lora_name
|
lycoris/utils.py
CHANGED
@@ -28,13 +28,11 @@ def extract_conv(
|
|
28 |
assert 1>=mode_param>=0
|
29 |
min_s = torch.max(S)*mode_param
|
30 |
lora_rank = torch.sum(S>min_s)
|
31 |
-
elif mode=='
|
32 |
assert 1>=mode_param>=0
|
33 |
s_cum = torch.cumsum(S, dim=0)
|
34 |
min_cum_sum = mode_param * torch.sum(S)
|
35 |
lora_rank = torch.sum(s_cum<min_cum_sum)
|
36 |
-
else:
|
37 |
-
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
38 |
lora_rank = max(1, lora_rank)
|
39 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
40 |
|
@@ -90,13 +88,11 @@ def extract_linear(
|
|
90 |
assert 1>=mode_param>=0
|
91 |
min_s = torch.max(S)*mode_param
|
92 |
lora_rank = torch.sum(S>min_s)
|
93 |
-
elif mode=='
|
94 |
assert 1>=mode_param>=0
|
95 |
s_cum = torch.cumsum(S, dim=0)
|
96 |
min_cum_sum = mode_param * torch.sum(S)
|
97 |
lora_rank = torch.sum(s_cum<min_cum_sum)
|
98 |
-
else:
|
99 |
-
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
100 |
lora_rank = max(1, lora_rank)
|
101 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
102 |
|
@@ -263,69 +259,6 @@ def merge_locon(
|
|
263 |
child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
|
264 |
del delta
|
265 |
|
266 |
-
merge(
|
267 |
-
LORA_PREFIX_TEXT_ENCODER,
|
268 |
-
base_model[0],
|
269 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
270 |
-
)
|
271 |
-
merge(
|
272 |
-
LORA_PREFIX_UNET,
|
273 |
-
base_model[2],
|
274 |
-
UNET_TARGET_REPLACE_MODULE
|
275 |
-
)
|
276 |
-
|
277 |
-
|
278 |
-
def merge_loha(
|
279 |
-
base_model,
|
280 |
-
loha_state_dict: Dict[str, torch.TensorType],
|
281 |
-
scale: float = 1.0,
|
282 |
-
device = 'cpu'
|
283 |
-
):
|
284 |
-
UNET_TARGET_REPLACE_MODULE = [
|
285 |
-
"Transformer2DModel",
|
286 |
-
"Attention",
|
287 |
-
"ResnetBlock2D",
|
288 |
-
"Downsample2D",
|
289 |
-
"Upsample2D"
|
290 |
-
]
|
291 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
292 |
-
LORA_PREFIX_UNET = 'lora_unet'
|
293 |
-
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
294 |
-
def merge(
|
295 |
-
prefix,
|
296 |
-
root_module: torch.nn.Module,
|
297 |
-
target_replace_modules
|
298 |
-
):
|
299 |
-
temp = {}
|
300 |
-
|
301 |
-
for name, module in tqdm(list(root_module.named_modules())):
|
302 |
-
if module.__class__.__name__ in target_replace_modules:
|
303 |
-
temp[name] = {}
|
304 |
-
for child_name, child_module in module.named_modules():
|
305 |
-
layer = child_module.__class__.__name__
|
306 |
-
if layer not in {'Linear', 'Conv2d'}:
|
307 |
-
continue
|
308 |
-
lora_name = prefix + '.' + name + '.' + child_name
|
309 |
-
lora_name = lora_name.replace('.', '_')
|
310 |
-
|
311 |
-
w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
|
312 |
-
w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
|
313 |
-
w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
|
314 |
-
w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
|
315 |
-
alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
|
316 |
-
dim = w1b.shape[0]
|
317 |
-
|
318 |
-
delta = (w1a @ w1b) * (w2a @ w2b)
|
319 |
-
delta = delta.reshape(child_module.weight.shape)
|
320 |
-
|
321 |
-
if layer == 'Conv2d':
|
322 |
-
child_module.weight.requires_grad_(False)
|
323 |
-
child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
|
324 |
-
elif layer == 'Linear':
|
325 |
-
child_module.weight.requires_grad_(False)
|
326 |
-
child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
|
327 |
-
del delta
|
328 |
-
|
329 |
merge(
|
330 |
LORA_PREFIX_TEXT_ENCODER,
|
331 |
base_model[0],
|
|
|
28 |
assert 1>=mode_param>=0
|
29 |
min_s = torch.max(S)*mode_param
|
30 |
lora_rank = torch.sum(S>min_s)
|
31 |
+
elif mode=='percentile':
|
32 |
assert 1>=mode_param>=0
|
33 |
s_cum = torch.cumsum(S, dim=0)
|
34 |
min_cum_sum = mode_param * torch.sum(S)
|
35 |
lora_rank = torch.sum(s_cum<min_cum_sum)
|
|
|
|
|
36 |
lora_rank = max(1, lora_rank)
|
37 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
38 |
|
|
|
88 |
assert 1>=mode_param>=0
|
89 |
min_s = torch.max(S)*mode_param
|
90 |
lora_rank = torch.sum(S>min_s)
|
91 |
+
elif mode=='percentile':
|
92 |
assert 1>=mode_param>=0
|
93 |
s_cum = torch.cumsum(S, dim=0)
|
94 |
min_cum_sum = mode_param * torch.sum(S)
|
95 |
lora_rank = torch.sum(s_cum<min_cum_sum)
|
|
|
|
|
96 |
lora_rank = max(1, lora_rank)
|
97 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
98 |
|
|
|
259 |
child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
|
260 |
del delta
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
merge(
|
263 |
LORA_PREFIX_TEXT_ENCODER,
|
264 |
base_model[0],
|
networks/check_lora_weights.py
CHANGED
@@ -21,7 +21,7 @@ def main(file):
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
-
print(f"{key},{
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
+
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
networks/extract_lora_from_models.py
CHANGED
@@ -45,13 +45,8 @@ def svd(args):
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
-
|
49 |
-
|
50 |
-
else:
|
51 |
-
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
52 |
-
|
53 |
-
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
54 |
-
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
55 |
assert len(lora_network_o.text_encoder_loras) == len(
|
56 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
57 |
|
@@ -90,28 +85,13 @@ def svd(args):
|
|
90 |
|
91 |
# make LoRA with svd
|
92 |
print("calculating by svd")
|
|
|
93 |
lora_weights = {}
|
94 |
with torch.no_grad():
|
95 |
for lora_name, mat in tqdm(list(diffs.items())):
|
96 |
-
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
97 |
conv2d = (len(mat.size()) == 4)
|
98 |
-
kernel_size = None if not conv2d else mat.size()[2:4]
|
99 |
-
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
100 |
-
|
101 |
-
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
102 |
-
out_dim, in_dim = mat.size()[0:2]
|
103 |
-
|
104 |
-
if args.device:
|
105 |
-
mat = mat.to(args.device)
|
106 |
-
|
107 |
-
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
108 |
-
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
109 |
-
|
110 |
if conv2d:
|
111 |
-
|
112 |
-
mat = mat.flatten(start_dim=1)
|
113 |
-
else:
|
114 |
-
mat = mat.squeeze()
|
115 |
|
116 |
U, S, Vh = torch.linalg.svd(mat)
|
117 |
|
@@ -128,27 +108,30 @@ def svd(args):
|
|
128 |
U = U.clamp(low_val, hi_val)
|
129 |
Vh = Vh.clamp(low_val, hi_val)
|
130 |
|
131 |
-
if conv2d:
|
132 |
-
U = U.reshape(out_dim, rank, 1, 1)
|
133 |
-
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
134 |
-
|
135 |
-
U = U.to("cpu").contiguous()
|
136 |
-
Vh = Vh.to("cpu").contiguous()
|
137 |
-
|
138 |
lora_weights[lora_name] = (U, Vh)
|
139 |
|
140 |
# make state dict for LoRA
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
lora_sd[lora_name + '.lora_down.weight'] = down_weight
|
145 |
-
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
150 |
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
print(f"Loading extracted LoRA weights: {info}")
|
153 |
|
154 |
dir_name = os.path.dirname(args.save_to)
|
@@ -156,9 +139,9 @@ def svd(args):
|
|
156 |
os.makedirs(dir_name, exist_ok=True)
|
157 |
|
158 |
# minimum metadata
|
159 |
-
metadata = {"
|
160 |
|
161 |
-
|
162 |
print(f"LoRA weights are saved to: {args.save_to}")
|
163 |
|
164 |
|
@@ -175,8 +158,6 @@ if __name__ == '__main__':
|
|
175 |
parser.add_argument("--save_to", type=str, default=None,
|
176 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
177 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
178 |
-
parser.add_argument("--conv_dim", type=int, default=None,
|
179 |
-
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
180 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
181 |
|
182 |
args = parser.parse_args()
|
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
+
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
49 |
+
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
|
|
|
|
|
|
|
|
|
|
50 |
assert len(lora_network_o.text_encoder_loras) == len(
|
51 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
52 |
|
|
|
85 |
|
86 |
# make LoRA with svd
|
87 |
print("calculating by svd")
|
88 |
+
rank = args.dim
|
89 |
lora_weights = {}
|
90 |
with torch.no_grad():
|
91 |
for lora_name, mat in tqdm(list(diffs.items())):
|
|
|
92 |
conv2d = (len(mat.size()) == 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
if conv2d:
|
94 |
+
mat = mat.squeeze()
|
|
|
|
|
|
|
95 |
|
96 |
U, S, Vh = torch.linalg.svd(mat)
|
97 |
|
|
|
108 |
U = U.clamp(low_val, hi_val)
|
109 |
Vh = Vh.clamp(low_val, hi_val)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
lora_weights[lora_name] = (U, Vh)
|
112 |
|
113 |
# make state dict for LoRA
|
114 |
+
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
115 |
+
lora_sd = lora_network_o.state_dict()
|
116 |
+
print(f"LoRA has {len(lora_sd)} weights.")
|
|
|
|
|
117 |
|
118 |
+
for key in list(lora_sd.keys()):
|
119 |
+
if "alpha" in key:
|
120 |
+
continue
|
121 |
+
|
122 |
+
lora_name = key.split('.')[0]
|
123 |
+
i = 0 if "lora_up" in key else 1
|
124 |
|
125 |
+
weights = lora_weights[lora_name][i]
|
126 |
+
# print(key, i, weights.size(), lora_sd[key].size())
|
127 |
+
if len(lora_sd[key].size()) == 4:
|
128 |
+
weights = weights.unsqueeze(2).unsqueeze(3)
|
129 |
+
|
130 |
+
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
131 |
+
lora_sd[key] = weights
|
132 |
+
|
133 |
+
# load state dict to LoRA and save it
|
134 |
+
info = lora_network_o.load_state_dict(lora_sd)
|
135 |
print(f"Loading extracted LoRA weights: {info}")
|
136 |
|
137 |
dir_name = os.path.dirname(args.save_to)
|
|
|
139 |
os.makedirs(dir_name, exist_ok=True)
|
140 |
|
141 |
# minimum metadata
|
142 |
+
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
143 |
|
144 |
+
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
145 |
print(f"LoRA weights are saved to: {args.save_to}")
|
146 |
|
147 |
|
|
|
158 |
parser.add_argument("--save_to", type=str, default=None,
|
159 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
160 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
|
|
|
|
161 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
|
163 |
args = parser.parse_args()
|
networks/lora.py
CHANGED
@@ -6,7 +6,6 @@
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
9 |
-
import numpy as np
|
10 |
import torch
|
11 |
|
12 |
from library import train_util
|
@@ -21,34 +20,22 @@ class LoRAModule(torch.nn.Module):
|
|
21 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
22 |
super().__init__()
|
23 |
self.lora_name = lora_name
|
|
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
|
|
|
|
28 |
else:
|
29 |
in_dim = org_module.in_features
|
30 |
out_dim = org_module.out_features
|
31 |
-
|
32 |
-
|
33 |
-
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
34 |
-
# if self.lora_dim != lora_dim:
|
35 |
-
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
36 |
-
# else:
|
37 |
-
self.lora_dim = lora_dim
|
38 |
-
|
39 |
-
if org_module.__class__.__name__ == 'Conv2d':
|
40 |
-
kernel_size = org_module.kernel_size
|
41 |
-
stride = org_module.stride
|
42 |
-
padding = org_module.padding
|
43 |
-
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
44 |
-
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
45 |
-
else:
|
46 |
-
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
47 |
-
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
48 |
|
49 |
if type(alpha) == torch.Tensor:
|
50 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
51 |
-
alpha =
|
52 |
self.scale = alpha / self.lora_dim
|
53 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
54 |
|
@@ -58,192 +45,69 @@ class LoRAModule(torch.nn.Module):
|
|
58 |
|
59 |
self.multiplier = multiplier
|
60 |
self.org_module = org_module # remove in applying
|
61 |
-
self.region = None
|
62 |
-
self.region_mask = None
|
63 |
|
64 |
def apply_to(self):
|
65 |
self.org_forward = self.org_module.forward
|
66 |
self.org_module.forward = self.forward
|
67 |
del self.org_module
|
68 |
|
69 |
-
def set_region(self, region):
|
70 |
-
self.region = region
|
71 |
-
self.region_mask = None
|
72 |
-
|
73 |
def forward(self, x):
|
74 |
-
|
75 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
76 |
-
|
77 |
-
# regional LoRA FIXME same as additional-network extension
|
78 |
-
if x.size()[1] % 77 == 0:
|
79 |
-
# print(f"LoRA for context: {self.lora_name}")
|
80 |
-
self.region = None
|
81 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
82 |
-
|
83 |
-
# calculate region mask first time
|
84 |
-
if self.region_mask is None:
|
85 |
-
if len(x.size()) == 4:
|
86 |
-
h, w = x.size()[2:4]
|
87 |
-
else:
|
88 |
-
seq_len = x.size()[1]
|
89 |
-
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
90 |
-
h = int(self.region.size()[0] / ratio + .5)
|
91 |
-
w = seq_len // h
|
92 |
-
|
93 |
-
r = self.region.to(x.device)
|
94 |
-
if r.dtype == torch.bfloat16:
|
95 |
-
r = r.to(torch.float)
|
96 |
-
r = r.unsqueeze(0).unsqueeze(1)
|
97 |
-
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
98 |
-
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
|
99 |
-
r = r.to(x.dtype)
|
100 |
-
|
101 |
-
if len(x.size()) == 3:
|
102 |
-
r = torch.reshape(r, (1, x.size()[1], -1))
|
103 |
-
|
104 |
-
self.region_mask = r
|
105 |
-
|
106 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
107 |
|
108 |
|
109 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
110 |
if network_dim is None:
|
111 |
network_dim = 4 # default
|
112 |
-
|
113 |
-
# extract dim/alpha for conv2d, and block dim
|
114 |
-
conv_dim = kwargs.get('conv_dim', None)
|
115 |
-
conv_alpha = kwargs.get('conv_alpha', None)
|
116 |
-
if conv_dim is not None:
|
117 |
-
conv_dim = int(conv_dim)
|
118 |
-
if conv_alpha is None:
|
119 |
-
conv_alpha = 1.0
|
120 |
-
else:
|
121 |
-
conv_alpha = float(conv_alpha)
|
122 |
-
|
123 |
-
"""
|
124 |
-
block_dims = kwargs.get("block_dims")
|
125 |
-
block_alphas = None
|
126 |
-
|
127 |
-
if block_dims is not None:
|
128 |
-
block_dims = [int(d) for d in block_dims.split(',')]
|
129 |
-
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
130 |
-
block_alphas = kwargs.get("block_alphas")
|
131 |
-
if block_alphas is None:
|
132 |
-
block_alphas = [1] * len(block_dims)
|
133 |
-
else:
|
134 |
-
block_alphas = [int(a) for a in block_alphas(',')]
|
135 |
-
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
136 |
-
|
137 |
-
conv_block_dims = kwargs.get("conv_block_dims")
|
138 |
-
conv_block_alphas = None
|
139 |
-
|
140 |
-
if conv_block_dims is not None:
|
141 |
-
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
142 |
-
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
143 |
-
conv_block_alphas = kwargs.get("conv_block_alphas")
|
144 |
-
if conv_block_alphas is None:
|
145 |
-
conv_block_alphas = [1] * len(conv_block_dims)
|
146 |
-
else:
|
147 |
-
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
148 |
-
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
149 |
-
"""
|
150 |
-
|
151 |
-
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
|
152 |
-
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
|
153 |
return network
|
154 |
|
155 |
|
156 |
-
def create_network_from_weights(multiplier, file, vae, text_encoder, unet,
|
157 |
-
if
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
weights_sd = torch.load(file, map_location='cpu')
|
163 |
|
164 |
-
# get dim
|
165 |
-
|
166 |
-
|
167 |
for key, value in weights_sd.items():
|
168 |
-
if '
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
# print(lora_name, value.size(), dim)
|
178 |
-
|
179 |
-
# support old LoRA without alpha
|
180 |
-
for key in modules_dim.keys():
|
181 |
-
if key not in modules_alpha:
|
182 |
-
modules_alpha = modules_dim[key]
|
183 |
-
|
184 |
-
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
185 |
network.weights_sd = weights_sd
|
186 |
return network
|
187 |
|
188 |
|
189 |
class LoRANetwork(torch.nn.Module):
|
190 |
-
# is it possible to apply conv_in and conv_out?
|
191 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
192 |
-
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
193 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
194 |
LORA_PREFIX_UNET = 'lora_unet'
|
195 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
196 |
|
197 |
-
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1
|
198 |
super().__init__()
|
199 |
self.multiplier = multiplier
|
200 |
-
|
201 |
self.lora_dim = lora_dim
|
202 |
self.alpha = alpha
|
203 |
-
self.conv_lora_dim = conv_lora_dim
|
204 |
-
self.conv_alpha = conv_alpha
|
205 |
-
|
206 |
-
if modules_dim is not None:
|
207 |
-
print(f"create LoRA network from weights")
|
208 |
-
else:
|
209 |
-
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
210 |
-
|
211 |
-
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
212 |
-
if self.apply_to_conv2d_3x3:
|
213 |
-
if self.conv_alpha is None:
|
214 |
-
self.conv_alpha = self.alpha
|
215 |
-
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
216 |
|
217 |
# create module instances
|
218 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
219 |
loras = []
|
220 |
for name, module in root_module.named_modules():
|
221 |
if module.__class__.__name__ in target_replace_modules:
|
222 |
-
# TODO get block index here
|
223 |
for child_name, child_module in module.named_modules():
|
224 |
-
|
225 |
-
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
226 |
-
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
227 |
-
if is_linear or is_conv2d:
|
228 |
lora_name = prefix + '.' + name + '.' + child_name
|
229 |
lora_name = lora_name.replace('.', '_')
|
230 |
-
|
231 |
-
if modules_dim is not None:
|
232 |
-
if lora_name not in modules_dim:
|
233 |
-
continue # no LoRA module in this weights file
|
234 |
-
dim = modules_dim[lora_name]
|
235 |
-
alpha = modules_alpha[lora_name]
|
236 |
-
else:
|
237 |
-
if is_linear or is_conv2d_1x1:
|
238 |
-
dim = self.lora_dim
|
239 |
-
alpha = self.alpha
|
240 |
-
elif self.apply_to_conv2d_3x3:
|
241 |
-
dim = self.conv_lora_dim
|
242 |
-
alpha = self.conv_alpha
|
243 |
-
else:
|
244 |
-
continue
|
245 |
-
|
246 |
-
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
247 |
loras.append(lora)
|
248 |
return loras
|
249 |
|
@@ -251,12 +115,7 @@ class LoRANetwork(torch.nn.Module):
|
|
251 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
252 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
253 |
|
254 |
-
|
255 |
-
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
256 |
-
if modules_dim is not None or self.conv_lora_dim is not None:
|
257 |
-
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
258 |
-
|
259 |
-
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
260 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
261 |
|
262 |
self.weights_sd = None
|
@@ -267,11 +126,6 @@ class LoRANetwork(torch.nn.Module):
|
|
267 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
268 |
names.add(lora.lora_name)
|
269 |
|
270 |
-
def set_multiplier(self, multiplier):
|
271 |
-
self.multiplier = multiplier
|
272 |
-
for lora in self.text_encoder_loras + self.unet_loras:
|
273 |
-
lora.multiplier = self.multiplier
|
274 |
-
|
275 |
def load_weights(self, file):
|
276 |
if os.path.splitext(file)[1] == '.safetensors':
|
277 |
from safetensors.torch import load_file, safe_open
|
@@ -381,18 +235,3 @@ class LoRANetwork(torch.nn.Module):
|
|
381 |
save_file(state_dict, file, metadata)
|
382 |
else:
|
383 |
torch.save(state_dict, file)
|
384 |
-
|
385 |
-
@ staticmethod
|
386 |
-
def set_regions(networks, image):
|
387 |
-
image = image.astype(np.float32) / 255.0
|
388 |
-
for i, network in enumerate(networks[:3]):
|
389 |
-
# NOTE: consider averaging overwrapping area
|
390 |
-
region = image[:, :, i]
|
391 |
-
if region.max() == 0:
|
392 |
-
continue
|
393 |
-
region = torch.tensor(region)
|
394 |
-
network.set_region(region)
|
395 |
-
|
396 |
-
def set_region(self, region):
|
397 |
-
for lora in self.unet_loras:
|
398 |
-
lora.set_region(region)
|
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
|
|
9 |
import torch
|
10 |
|
11 |
from library import train_util
|
|
|
20 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
21 |
super().__init__()
|
22 |
self.lora_name = lora_name
|
23 |
+
self.lora_dim = lora_dim
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
28 |
+
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
29 |
+
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
30 |
else:
|
31 |
in_dim = org_module.in_features
|
32 |
out_dim = org_module.out_features
|
33 |
+
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
34 |
+
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
if type(alpha) == torch.Tensor:
|
37 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
38 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
39 |
self.scale = alpha / self.lora_dim
|
40 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
41 |
|
|
|
45 |
|
46 |
self.multiplier = multiplier
|
47 |
self.org_module = org_module # remove in applying
|
|
|
|
|
48 |
|
49 |
def apply_to(self):
|
50 |
self.org_forward = self.org_module.forward
|
51 |
self.org_module.forward = self.forward
|
52 |
del self.org_module
|
53 |
|
|
|
|
|
|
|
|
|
54 |
def forward(self, x):
|
55 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
59 |
if network_dim is None:
|
60 |
network_dim = 4 # default
|
61 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
return network
|
63 |
|
64 |
|
65 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
66 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
67 |
+
from safetensors.torch import load_file, safe_open
|
68 |
+
weights_sd = load_file(file)
|
69 |
+
else:
|
70 |
+
weights_sd = torch.load(file, map_location='cpu')
|
|
|
71 |
|
72 |
+
# get dim (rank)
|
73 |
+
network_alpha = None
|
74 |
+
network_dim = None
|
75 |
for key, value in weights_sd.items():
|
76 |
+
if network_alpha is None and 'alpha' in key:
|
77 |
+
network_alpha = value
|
78 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
79 |
+
network_dim = value.size()[0]
|
80 |
+
|
81 |
+
if network_alpha is None:
|
82 |
+
network_alpha = network_dim
|
83 |
+
|
84 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
network.weights_sd = weights_sd
|
86 |
return network
|
87 |
|
88 |
|
89 |
class LoRANetwork(torch.nn.Module):
|
|
|
90 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
|
|
91 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
|
95 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
96 |
super().__init__()
|
97 |
self.multiplier = multiplier
|
|
|
98 |
self.lora_dim = lora_dim
|
99 |
self.alpha = alpha
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# create module instances
|
102 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
103 |
loras = []
|
104 |
for name, module in root_module.named_modules():
|
105 |
if module.__class__.__name__ in target_replace_modules:
|
|
|
106 |
for child_name, child_module in module.named_modules():
|
107 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
|
|
|
|
|
|
108 |
lora_name = prefix + '.' + name + '.' + child_name
|
109 |
lora_name = lora_name.replace('.', '_')
|
110 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
loras.append(lora)
|
112 |
return loras
|
113 |
|
|
|
115 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
116 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
117 |
|
118 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
|
|
|
|
|
|
|
|
|
|
119 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
120 |
|
121 |
self.weights_sd = None
|
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
129 |
def load_weights(self, file):
|
130 |
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
from safetensors.torch import load_file, safe_open
|
|
|
235 |
save_file(state_dict, file, metadata)
|
236 |
else:
|
237 |
torch.save(state_dict, file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
networks/merge_lora.py
CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|
48 |
for name, module in root_module.named_modules():
|
49 |
if module.__class__.__name__ in target_replace_modules:
|
50 |
for child_name, child_module in module.named_modules():
|
51 |
-
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
52 |
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
lora_name = lora_name.replace('.', '_')
|
54 |
name_to_module[lora_name] = child_module
|
@@ -80,19 +80,13 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|
80 |
|
81 |
# W <- W + U * D
|
82 |
weight = module.weight
|
83 |
-
# print(module_name, down_weight.size(), up_weight.size())
|
84 |
if len(weight.size()) == 2:
|
85 |
# linear
|
86 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
87 |
-
|
88 |
-
# conv2d
|
89 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
90 |
).unsqueeze(2).unsqueeze(3) * scale
|
91 |
-
else:
|
92 |
-
# conv2d 3x3
|
93 |
-
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
94 |
-
# print(conved.size(), weight.size(), module.stride, module.padding)
|
95 |
-
weight = weight + ratio * conved * scale
|
96 |
|
97 |
module.weight = torch.nn.Parameter(weight)
|
98 |
|
@@ -129,7 +123,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|
129 |
alphas[lora_module_name] = alpha
|
130 |
if lora_module_name not in base_alphas:
|
131 |
base_alphas[lora_module_name] = alpha
|
132 |
-
|
133 |
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
134 |
|
135 |
# merge
|
@@ -151,7 +145,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|
151 |
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
152 |
else:
|
153 |
merged_sd[key] = lora_sd[key] * scale
|
154 |
-
|
155 |
# set alpha to sd
|
156 |
for lora_module_name, alpha in base_alphas.items():
|
157 |
key = lora_module_name + ".alpha"
|
|
|
48 |
for name, module in root_module.named_modules():
|
49 |
if module.__class__.__name__ in target_replace_modules:
|
50 |
for child_name, child_module in module.named_modules():
|
51 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
52 |
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
lora_name = lora_name.replace('.', '_')
|
54 |
name_to_module[lora_name] = child_module
|
|
|
80 |
|
81 |
# W <- W + U * D
|
82 |
weight = module.weight
|
|
|
83 |
if len(weight.size()) == 2:
|
84 |
# linear
|
85 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
+
else:
|
87 |
+
# conv2d
|
88 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
89 |
).unsqueeze(2).unsqueeze(3) * scale
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
module.weight = torch.nn.Parameter(weight)
|
92 |
|
|
|
123 |
alphas[lora_module_name] = alpha
|
124 |
if lora_module_name not in base_alphas:
|
125 |
base_alphas[lora_module_name] = alpha
|
126 |
+
|
127 |
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
128 |
|
129 |
# merge
|
|
|
145 |
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
146 |
else:
|
147 |
merged_sd[key] = lora_sd[key] * scale
|
148 |
+
|
149 |
# set alpha to sd
|
150 |
for lora_module_name, alpha in base_alphas.items():
|
151 |
key = lora_module_name + ".alpha"
|
networks/resize_lora.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
-
# Thanks to cloneofsimo
|
4 |
|
5 |
import argparse
|
|
|
6 |
import torch
|
7 |
from safetensors.torch import load_file, save_file, safe_open
|
8 |
from tqdm import tqdm
|
9 |
from library import train_util, model_util
|
10 |
-
import numpy as np
|
11 |
|
12 |
-
MIN_SV = 1e-6
|
13 |
|
14 |
def load_state_dict(file_name, dtype):
|
15 |
if model_util.is_safetensors(file_name):
|
@@ -39,149 +38,12 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|
39 |
torch.save(model, file_name)
|
40 |
|
41 |
|
42 |
-
def
|
43 |
-
original_sum = float(torch.sum(S))
|
44 |
-
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
45 |
-
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
46 |
-
if index >= len(S):
|
47 |
-
index = len(S) - 1
|
48 |
-
|
49 |
-
return index
|
50 |
-
|
51 |
-
|
52 |
-
def index_sv_fro(S, target):
|
53 |
-
S_squared = S.pow(2)
|
54 |
-
s_fro_sq = float(torch.sum(S_squared))
|
55 |
-
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
56 |
-
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
57 |
-
if index >= len(S):
|
58 |
-
index = len(S) - 1
|
59 |
-
|
60 |
-
return index
|
61 |
-
|
62 |
-
|
63 |
-
# Modified from Kohaku-blueleaf's extract/merge functions
|
64 |
-
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
65 |
-
out_size, in_size, kernel_size, _ = weight.size()
|
66 |
-
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
67 |
-
|
68 |
-
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
69 |
-
lora_rank = param_dict["new_rank"]
|
70 |
-
|
71 |
-
U = U[:, :lora_rank]
|
72 |
-
S = S[:lora_rank]
|
73 |
-
U = U @ torch.diag(S)
|
74 |
-
Vh = Vh[:lora_rank, :]
|
75 |
-
|
76 |
-
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
77 |
-
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
78 |
-
del U, S, Vh, weight
|
79 |
-
return param_dict
|
80 |
-
|
81 |
-
|
82 |
-
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
83 |
-
out_size, in_size = weight.size()
|
84 |
-
|
85 |
-
U, S, Vh = torch.linalg.svd(weight.to(device))
|
86 |
-
|
87 |
-
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
88 |
-
lora_rank = param_dict["new_rank"]
|
89 |
-
|
90 |
-
U = U[:, :lora_rank]
|
91 |
-
S = S[:lora_rank]
|
92 |
-
U = U @ torch.diag(S)
|
93 |
-
Vh = Vh[:lora_rank, :]
|
94 |
-
|
95 |
-
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
96 |
-
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
97 |
-
del U, S, Vh, weight
|
98 |
-
return param_dict
|
99 |
-
|
100 |
-
|
101 |
-
def merge_conv(lora_down, lora_up, device):
|
102 |
-
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
103 |
-
out_size, out_rank, _, _ = lora_up.shape
|
104 |
-
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
105 |
-
|
106 |
-
lora_down = lora_down.to(device)
|
107 |
-
lora_up = lora_up.to(device)
|
108 |
-
|
109 |
-
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
110 |
-
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
111 |
-
del lora_up, lora_down
|
112 |
-
return weight
|
113 |
-
|
114 |
-
|
115 |
-
def merge_linear(lora_down, lora_up, device):
|
116 |
-
in_rank, in_size = lora_down.shape
|
117 |
-
out_size, out_rank = lora_up.shape
|
118 |
-
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
119 |
-
|
120 |
-
lora_down = lora_down.to(device)
|
121 |
-
lora_up = lora_up.to(device)
|
122 |
-
|
123 |
-
weight = lora_up @ lora_down
|
124 |
-
del lora_up, lora_down
|
125 |
-
return weight
|
126 |
-
|
127 |
-
|
128 |
-
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
129 |
-
param_dict = {}
|
130 |
-
|
131 |
-
if dynamic_method=="sv_ratio":
|
132 |
-
# Calculate new dim and alpha based off ratio
|
133 |
-
max_sv = S[0]
|
134 |
-
min_sv = max_sv/dynamic_param
|
135 |
-
new_rank = max(torch.sum(S > min_sv).item(),1)
|
136 |
-
new_alpha = float(scale*new_rank)
|
137 |
-
|
138 |
-
elif dynamic_method=="sv_cumulative":
|
139 |
-
# Calculate new dim and alpha based off cumulative sum
|
140 |
-
new_rank = index_sv_cumulative(S, dynamic_param)
|
141 |
-
new_rank = max(new_rank, 1)
|
142 |
-
new_alpha = float(scale*new_rank)
|
143 |
-
|
144 |
-
elif dynamic_method=="sv_fro":
|
145 |
-
# Calculate new dim and alpha based off sqrt sum of squares
|
146 |
-
new_rank = index_sv_fro(S, dynamic_param)
|
147 |
-
new_rank = min(max(new_rank, 1), len(S)-1)
|
148 |
-
new_alpha = float(scale*new_rank)
|
149 |
-
else:
|
150 |
-
new_rank = rank
|
151 |
-
new_alpha = float(scale*new_rank)
|
152 |
-
|
153 |
-
|
154 |
-
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
155 |
-
new_rank = 1
|
156 |
-
new_alpha = float(scale*new_rank)
|
157 |
-
elif new_rank > rank: # cap max rank at rank
|
158 |
-
new_rank = rank
|
159 |
-
new_alpha = float(scale*new_rank)
|
160 |
-
|
161 |
-
|
162 |
-
# Calculate resize info
|
163 |
-
s_sum = torch.sum(torch.abs(S))
|
164 |
-
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
165 |
-
|
166 |
-
S_squared = S.pow(2)
|
167 |
-
s_fro = torch.sqrt(torch.sum(S_squared))
|
168 |
-
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
169 |
-
fro_percent = float(s_red_fro/s_fro)
|
170 |
-
|
171 |
-
param_dict["new_rank"] = new_rank
|
172 |
-
param_dict["new_alpha"] = new_alpha
|
173 |
-
param_dict["sum_retained"] = (s_rank)/s_sum
|
174 |
-
param_dict["fro_retained"] = fro_percent
|
175 |
-
param_dict["max_ratio"] = S[0]/S[new_rank]
|
176 |
-
|
177 |
-
return param_dict
|
178 |
-
|
179 |
-
|
180 |
-
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
181 |
network_alpha = None
|
182 |
network_dim = None
|
183 |
verbose_str = "\n"
|
184 |
-
|
|
|
185 |
|
186 |
# Extract loaded lora dim and alpha
|
187 |
for key, value in lora_sd.items():
|
@@ -195,9 +57,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|
195 |
network_alpha = network_dim
|
196 |
|
197 |
scale = network_alpha/network_dim
|
|
|
198 |
|
199 |
-
|
200 |
-
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
201 |
|
202 |
lora_down_weight = None
|
203 |
lora_up_weight = None
|
@@ -206,6 +68,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|
206 |
block_down_name = None
|
207 |
block_up_name = None
|
208 |
|
|
|
209 |
with torch.no_grad():
|
210 |
for key, value in tqdm(lora_sd.items()):
|
211 |
if 'lora_down' in key:
|
@@ -222,43 +85,57 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|
222 |
conv2d = (len(lora_down_weight.size()) == 4)
|
223 |
|
224 |
if conv2d:
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
if verbose:
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
fro_list.append(float(fro_retained))
|
237 |
|
238 |
-
|
239 |
-
|
|
|
240 |
|
241 |
-
|
242 |
-
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
243 |
-
else:
|
244 |
-
verbose_str+=f"\n"
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
block_down_name = None
|
252 |
block_up_name = None
|
253 |
lora_down_weight = None
|
254 |
lora_up_weight = None
|
255 |
weights_loaded = False
|
256 |
-
del param_dict
|
257 |
|
258 |
if verbose:
|
259 |
print(verbose_str)
|
260 |
-
|
261 |
-
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
262 |
print("resizing complete")
|
263 |
return o_lora_sd, network_dim, new_alpha
|
264 |
|
@@ -274,9 +151,6 @@ def resize(args):
|
|
274 |
return torch.bfloat16
|
275 |
return None
|
276 |
|
277 |
-
if args.dynamic_method and not args.dynamic_param:
|
278 |
-
raise Exception("If using dynamic_method, then dynamic_param is required")
|
279 |
-
|
280 |
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
281 |
save_dtype = str_to_dtype(args.save_precision)
|
282 |
if save_dtype is None:
|
@@ -285,23 +159,17 @@ def resize(args):
|
|
285 |
print("loading Model...")
|
286 |
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
287 |
|
288 |
-
print("
|
289 |
-
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.
|
290 |
|
291 |
# update metadata
|
292 |
if metadata is None:
|
293 |
metadata = {}
|
294 |
|
295 |
comment = metadata.get("ss_training_comment", "")
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
metadata["ss_network_dim"] = str(args.new_rank)
|
300 |
-
metadata["ss_network_alpha"] = str(new_alpha)
|
301 |
-
else:
|
302 |
-
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
303 |
-
metadata["ss_network_dim"] = 'Dynamic'
|
304 |
-
metadata["ss_network_alpha"] = 'Dynamic'
|
305 |
|
306 |
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
307 |
metadata["sshs_model_hash"] = model_hash
|
@@ -325,11 +193,6 @@ if __name__ == '__main__':
|
|
325 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
326 |
parser.add_argument("--verbose", action="store_true",
|
327 |
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
328 |
-
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
329 |
-
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
330 |
-
parser.add_argument("--dynamic_param", type=float, default=None,
|
331 |
-
help="Specify target for dynamic reduction")
|
332 |
-
|
333 |
|
334 |
args = parser.parse_args()
|
335 |
resize(args)
|
|
|
1 |
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo and kohya
|
4 |
|
5 |
import argparse
|
6 |
+
import os
|
7 |
import torch
|
8 |
from safetensors.torch import load_file, save_file, safe_open
|
9 |
from tqdm import tqdm
|
10 |
from library import train_util, model_util
|
|
|
11 |
|
|
|
12 |
|
13 |
def load_state_dict(file_name, dtype):
|
14 |
if model_util.is_safetensors(file_name):
|
|
|
38 |
torch.save(model, file_name)
|
39 |
|
40 |
|
41 |
+
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
network_alpha = None
|
43 |
network_dim = None
|
44 |
verbose_str = "\n"
|
45 |
+
|
46 |
+
CLAMP_QUANTILE = 0.99
|
47 |
|
48 |
# Extract loaded lora dim and alpha
|
49 |
for key, value in lora_sd.items():
|
|
|
57 |
network_alpha = network_dim
|
58 |
|
59 |
scale = network_alpha/network_dim
|
60 |
+
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
61 |
|
62 |
+
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
|
|
63 |
|
64 |
lora_down_weight = None
|
65 |
lora_up_weight = None
|
|
|
68 |
block_down_name = None
|
69 |
block_up_name = None
|
70 |
|
71 |
+
print("resizing lora...")
|
72 |
with torch.no_grad():
|
73 |
for key, value in tqdm(lora_sd.items()):
|
74 |
if 'lora_down' in key:
|
|
|
85 |
conv2d = (len(lora_down_weight.size()) == 4)
|
86 |
|
87 |
if conv2d:
|
88 |
+
lora_down_weight = lora_down_weight.squeeze()
|
89 |
+
lora_up_weight = lora_up_weight.squeeze()
|
90 |
+
|
91 |
+
if device:
|
92 |
+
org_device = lora_up_weight.device
|
93 |
+
lora_up_weight = lora_up_weight.to(args.device)
|
94 |
+
lora_down_weight = lora_down_weight.to(args.device)
|
95 |
+
|
96 |
+
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
97 |
+
|
98 |
+
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
99 |
|
100 |
if verbose:
|
101 |
+
s_sum = torch.sum(torch.abs(S))
|
102 |
+
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
103 |
+
verbose_str+=f"{block_down_name:76} | "
|
104 |
+
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
|
|
105 |
|
106 |
+
U = U[:, :new_rank]
|
107 |
+
S = S[:new_rank]
|
108 |
+
U = U @ torch.diag(S)
|
109 |
|
110 |
+
Vh = Vh[:new_rank, :]
|
|
|
|
|
|
|
111 |
|
112 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
113 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
114 |
+
low_val = -hi_val
|
115 |
+
|
116 |
+
U = U.clamp(low_val, hi_val)
|
117 |
+
Vh = Vh.clamp(low_val, hi_val)
|
118 |
+
|
119 |
+
if conv2d:
|
120 |
+
U = U.unsqueeze(2).unsqueeze(3)
|
121 |
+
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
122 |
+
|
123 |
+
if device:
|
124 |
+
U = U.to(org_device)
|
125 |
+
Vh = Vh.to(org_device)
|
126 |
+
|
127 |
+
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
128 |
+
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
129 |
+
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
130 |
|
131 |
block_down_name = None
|
132 |
block_up_name = None
|
133 |
lora_down_weight = None
|
134 |
lora_up_weight = None
|
135 |
weights_loaded = False
|
|
|
136 |
|
137 |
if verbose:
|
138 |
print(verbose_str)
|
|
|
|
|
139 |
print("resizing complete")
|
140 |
return o_lora_sd, network_dim, new_alpha
|
141 |
|
|
|
151 |
return torch.bfloat16
|
152 |
return None
|
153 |
|
|
|
|
|
|
|
154 |
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
155 |
save_dtype = str_to_dtype(args.save_precision)
|
156 |
if save_dtype is None:
|
|
|
159 |
print("loading Model...")
|
160 |
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
161 |
|
162 |
+
print("resizing rank...")
|
163 |
+
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
164 |
|
165 |
# update metadata
|
166 |
if metadata is None:
|
167 |
metadata = {}
|
168 |
|
169 |
comment = metadata.get("ss_training_comment", "")
|
170 |
+
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
171 |
+
metadata["ss_network_dim"] = str(args.new_rank)
|
172 |
+
metadata["ss_network_alpha"] = str(new_alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
175 |
metadata["sshs_model_hash"] = model_hash
|
|
|
193 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
194 |
parser.add_argument("--verbose", action="store_true",
|
195 |
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
args = parser.parse_args()
|
198 |
resize(args)
|
networks/svd_merge_lora.py
CHANGED
@@ -23,20 +23,19 @@ def load_state_dict(file_name, dtype):
|
|
23 |
return sd
|
24 |
|
25 |
|
26 |
-
def save_to_file(file_name, state_dict, dtype):
|
27 |
if dtype is not None:
|
28 |
for key in list(state_dict.keys()):
|
29 |
if type(state_dict[key]) == torch.Tensor:
|
30 |
state_dict[key] = state_dict[key].to(dtype)
|
31 |
|
32 |
if os.path.splitext(file_name)[1] == '.safetensors':
|
33 |
-
save_file(
|
34 |
else:
|
35 |
-
torch.save(
|
36 |
|
37 |
|
38 |
-
def merge_lora_models(models, ratios, new_rank,
|
39 |
-
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
40 |
merged_sd = {}
|
41 |
for model, ratio in zip(models, ratios):
|
42 |
print(f"loading: {model}")
|
@@ -59,12 +58,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|
59 |
in_dim = down_weight.size()[1]
|
60 |
out_dim = up_weight.size()[0]
|
61 |
conv2d = len(down_weight.size()) == 4
|
62 |
-
|
63 |
-
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
64 |
|
65 |
# make original weight if not exist
|
66 |
if lora_module_name not in merged_sd:
|
67 |
-
weight = torch.zeros((out_dim, in_dim,
|
68 |
if device:
|
69 |
weight = weight.to(device)
|
70 |
else:
|
@@ -77,18 +75,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|
77 |
|
78 |
# W <- W + U * D
|
79 |
scale = (alpha / network_dim)
|
80 |
-
|
81 |
-
if device: # and isinstance(scale, torch.Tensor):
|
82 |
-
scale = scale.to(device)
|
83 |
-
|
84 |
if not conv2d: # linear
|
85 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
-
|
87 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
88 |
).unsqueeze(2).unsqueeze(3) * scale
|
89 |
-
else:
|
90 |
-
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
91 |
-
weight = weight + ratio * conved * scale
|
92 |
|
93 |
merged_sd[lora_module_name] = weight
|
94 |
|
@@ -98,26 +89,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|
98 |
with torch.no_grad():
|
99 |
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
100 |
conv2d = (len(mat.size()) == 4)
|
101 |
-
kernel_size = None if not conv2d else mat.size()[2:4]
|
102 |
-
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
103 |
-
out_dim, in_dim = mat.size()[0:2]
|
104 |
-
|
105 |
if conv2d:
|
106 |
-
|
107 |
-
mat = mat.flatten(start_dim=1)
|
108 |
-
else:
|
109 |
-
mat = mat.squeeze()
|
110 |
-
|
111 |
-
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
112 |
-
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
113 |
|
114 |
U, S, Vh = torch.linalg.svd(mat)
|
115 |
|
116 |
-
U = U[:, :
|
117 |
-
S = S[:
|
118 |
U = U @ torch.diag(S)
|
119 |
|
120 |
-
Vh = Vh[:
|
121 |
|
122 |
dist = torch.cat([U.flatten(), Vh.flatten()])
|
123 |
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
@@ -126,16 +107,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|
126 |
U = U.clamp(low_val, hi_val)
|
127 |
Vh = Vh.clamp(low_val, hi_val)
|
128 |
|
129 |
-
if conv2d:
|
130 |
-
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
131 |
-
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
132 |
-
|
133 |
up_weight = U
|
134 |
down_weight = Vh
|
135 |
|
|
|
|
|
|
|
|
|
136 |
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
137 |
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
138 |
-
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(
|
139 |
|
140 |
return merged_lora_sd
|
141 |
|
@@ -157,11 +138,10 @@ def merge(args):
|
|
157 |
if save_dtype is None:
|
158 |
save_dtype = merge_dtype
|
159 |
|
160 |
-
|
161 |
-
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
162 |
|
163 |
print(f"saving model to: {args.save_to}")
|
164 |
-
save_to_file(args.save_to, state_dict, save_dtype)
|
165 |
|
166 |
|
167 |
if __name__ == '__main__':
|
@@ -178,8 +158,6 @@ if __name__ == '__main__':
|
|
178 |
help="ratios for each model / それぞれのLoRAモデルの比率")
|
179 |
parser.add_argument("--new_rank", type=int, default=4,
|
180 |
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
181 |
-
parser.add_argument("--new_conv_rank", type=int, default=None,
|
182 |
-
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
183 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
184 |
|
185 |
args = parser.parse_args()
|
|
|
23 |
return sd
|
24 |
|
25 |
|
26 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
27 |
if dtype is not None:
|
28 |
for key in list(state_dict.keys()):
|
29 |
if type(state_dict[key]) == torch.Tensor:
|
30 |
state_dict[key] = state_dict[key].to(dtype)
|
31 |
|
32 |
if os.path.splitext(file_name)[1] == '.safetensors':
|
33 |
+
save_file(model, file_name)
|
34 |
else:
|
35 |
+
torch.save(model, file_name)
|
36 |
|
37 |
|
38 |
+
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|
|
39 |
merged_sd = {}
|
40 |
for model, ratio in zip(models, ratios):
|
41 |
print(f"loading: {model}")
|
|
|
58 |
in_dim = down_weight.size()[1]
|
59 |
out_dim = up_weight.size()[0]
|
60 |
conv2d = len(down_weight.size()) == 4
|
61 |
+
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
|
|
62 |
|
63 |
# make original weight if not exist
|
64 |
if lora_module_name not in merged_sd:
|
65 |
+
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
66 |
if device:
|
67 |
weight = weight.to(device)
|
68 |
else:
|
|
|
75 |
|
76 |
# W <- W + U * D
|
77 |
scale = (alpha / network_dim)
|
|
|
|
|
|
|
|
|
78 |
if not conv2d: # linear
|
79 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
80 |
+
else:
|
81 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
82 |
).unsqueeze(2).unsqueeze(3) * scale
|
|
|
|
|
|
|
83 |
|
84 |
merged_sd[lora_module_name] = weight
|
85 |
|
|
|
89 |
with torch.no_grad():
|
90 |
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
91 |
conv2d = (len(mat.size()) == 4)
|
|
|
|
|
|
|
|
|
92 |
if conv2d:
|
93 |
+
mat = mat.squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
U, S, Vh = torch.linalg.svd(mat)
|
96 |
|
97 |
+
U = U[:, :new_rank]
|
98 |
+
S = S[:new_rank]
|
99 |
U = U @ torch.diag(S)
|
100 |
|
101 |
+
Vh = Vh[:new_rank, :]
|
102 |
|
103 |
dist = torch.cat([U.flatten(), Vh.flatten()])
|
104 |
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
|
|
107 |
U = U.clamp(low_val, hi_val)
|
108 |
Vh = Vh.clamp(low_val, hi_val)
|
109 |
|
|
|
|
|
|
|
|
|
110 |
up_weight = U
|
111 |
down_weight = Vh
|
112 |
|
113 |
+
if conv2d:
|
114 |
+
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
115 |
+
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
116 |
+
|
117 |
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
118 |
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
119 |
+
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
120 |
|
121 |
return merged_lora_sd
|
122 |
|
|
|
138 |
if save_dtype is None:
|
139 |
save_dtype = merge_dtype
|
140 |
|
141 |
+
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
|
|
142 |
|
143 |
print(f"saving model to: {args.save_to}")
|
144 |
+
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
145 |
|
146 |
|
147 |
if __name__ == '__main__':
|
|
|
158 |
help="ratios for each model / それぞれのLoRAモデルの比率")
|
159 |
parser.add_argument("--new_rank", type=int, default=4,
|
160 |
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
|
|
|
|
161 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
|
163 |
args = parser.parse_args()
|
requirements.txt
CHANGED
@@ -12,8 +12,6 @@ safetensors==0.2.6
|
|
12 |
gradio==3.16.2
|
13 |
altair==4.2.2
|
14 |
easygui==0.98.3
|
15 |
-
toml==0.10.2
|
16 |
-
voluptuous==0.13.1
|
17 |
# for BLIP captioning
|
18 |
requests==2.28.2
|
19 |
timm==0.6.12
|
@@ -23,4 +21,5 @@ fairscale==0.4.13
|
|
23 |
tensorflow==2.10.1
|
24 |
huggingface-hub==0.12.0
|
25 |
# for kohya_ss library
|
|
|
26 |
.
|
|
|
12 |
gradio==3.16.2
|
13 |
altair==4.2.2
|
14 |
easygui==0.98.3
|
|
|
|
|
15 |
# for BLIP captioning
|
16 |
requests==2.28.2
|
17 |
timm==0.6.12
|
|
|
21 |
tensorflow==2.10.1
|
22 |
huggingface-hub==0.12.0
|
23 |
# for kohya_ss library
|
24 |
+
#locon.locon_kohya
|
25 |
.
|
requirements_startup.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
transformers==4.26.0
|
3 |
+
ftfy==6.1.1
|
4 |
+
albumentations==1.3.0
|
5 |
+
opencv-python==4.7.0.68
|
6 |
+
einops==0.6.0
|
7 |
+
diffusers[torch]==0.10.2
|
8 |
+
pytorch-lightning==1.9.0
|
9 |
+
bitsandbytes==0.35.0
|
10 |
+
tensorboard==2.10.1
|
11 |
+
safetensors==0.2.6
|
12 |
+
gradio==3.18.0
|
13 |
+
altair==4.2.2
|
14 |
+
easygui==0.98.3
|
15 |
+
# for BLIP captioning
|
16 |
+
requests==2.28.2
|
17 |
+
timm==0.4.12
|
18 |
+
fairscale==0.4.4
|
19 |
+
# for WD14 captioning
|
20 |
+
tensorflow==2.10.1
|
21 |
+
huggingface-hub==0.12.0
|
22 |
+
# for kohya_ss library
|
23 |
+
.
|
train_db.py
CHANGED
@@ -15,11 +15,7 @@ import diffusers
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
-
|
19 |
-
from library.config_util import (
|
20 |
-
ConfigSanitizer,
|
21 |
-
BlueprintGenerator,
|
22 |
-
)
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
@@ -37,33 +33,24 @@ def train(args):
|
|
37 |
|
38 |
tokenizer = train_util.load_tokenizer(args)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
if any(getattr(args, attr) is not None for attr in ignored):
|
46 |
-
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
47 |
-
else:
|
48 |
-
user_config = {
|
49 |
-
"datasets": [{
|
50 |
-
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
51 |
-
}]
|
52 |
-
}
|
53 |
-
|
54 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
if args.no_token_padding:
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
if args.debug_dataset:
|
61 |
-
train_util.debug_dataset(
|
62 |
return
|
63 |
|
64 |
-
if cache_latents:
|
65 |
-
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
-
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
|
@@ -104,7 +91,7 @@ def train(args):
|
|
104 |
vae.requires_grad_(False)
|
105 |
vae.eval()
|
106 |
with torch.no_grad():
|
107 |
-
|
108 |
vae.to("cpu")
|
109 |
if torch.cuda.is_available():
|
110 |
torch.cuda.empty_cache()
|
@@ -128,18 +115,38 @@ def train(args):
|
|
128 |
|
129 |
# 学習に必要なクラスを準備する
|
130 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if train_text_encoder:
|
132 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
133 |
else:
|
134 |
trainable_params = unet.parameters()
|
135 |
|
136 |
-
|
|
|
137 |
|
138 |
# dataloaderを準備する
|
139 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
140 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
141 |
train_dataloader = torch.utils.data.DataLoader(
|
142 |
-
|
143 |
|
144 |
# 学習ステップ数を計算する
|
145 |
if args.max_train_epochs is not None:
|
@@ -149,10 +156,9 @@ def train(args):
|
|
149 |
if args.stop_text_encoder_training is None:
|
150 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
151 |
|
152 |
-
# lr schedulerを用意する
|
153 |
-
lr_scheduler =
|
154 |
-
|
155 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
156 |
|
157 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
158 |
if args.full_fp16:
|
@@ -189,8 +195,8 @@ def train(args):
|
|
189 |
# 学習する
|
190 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
191 |
print("running training / 学習開始")
|
192 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
193 |
-
print(f" num reg images / 正則化画像の数: {
|
194 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
195 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
196 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -211,7 +217,7 @@ def train(args):
|
|
211 |
loss_total = 0.0
|
212 |
for epoch in range(num_train_epochs):
|
213 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
214 |
-
|
215 |
|
216 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
217 |
unet.train()
|
@@ -275,12 +281,12 @@ def train(args):
|
|
275 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
276 |
|
277 |
accelerator.backward(loss)
|
278 |
-
if accelerator.sync_gradients
|
279 |
if train_text_encoder:
|
280 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
281 |
else:
|
282 |
params_to_clip = unet.parameters()
|
283 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
284 |
|
285 |
optimizer.step()
|
286 |
lr_scheduler.step()
|
@@ -291,13 +297,9 @@ def train(args):
|
|
291 |
progress_bar.update(1)
|
292 |
global_step += 1
|
293 |
|
294 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
295 |
-
|
296 |
current_loss = loss.detach().item()
|
297 |
if args.logging_dir is not None:
|
298 |
-
logs = {"loss": current_loss, "lr":
|
299 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
300 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
301 |
accelerator.log(logs, step=global_step)
|
302 |
|
303 |
if epoch == 0:
|
@@ -324,8 +326,6 @@ def train(args):
|
|
324 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
325 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
326 |
|
327 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
328 |
-
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
@@ -352,8 +352,6 @@ if __name__ == '__main__':
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
355 |
-
train_util.add_optimizer_arguments(parser)
|
356 |
-
config_util.add_config_arguments(parser)
|
357 |
|
358 |
parser.add_argument("--no_token_padding", action="store_true",
|
359 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
+
from library.train_util import DreamBoothDataset
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def collate_fn(examples):
|
|
|
33 |
|
34 |
tokenizer = train_util.load_tokenizer(args)
|
35 |
|
36 |
+
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
37 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
38 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
39 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
40 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
if args.no_token_padding:
|
43 |
+
train_dataset.disable_token_padding()
|
44 |
+
|
45 |
+
# 学習データのdropout率を設定する
|
46 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
47 |
+
|
48 |
+
train_dataset.make_buckets()
|
49 |
|
50 |
if args.debug_dataset:
|
51 |
+
train_util.debug_dataset(train_dataset)
|
52 |
return
|
53 |
|
|
|
|
|
|
|
54 |
# acceleratorを準備する
|
55 |
print("prepare accelerator")
|
56 |
|
|
|
91 |
vae.requires_grad_(False)
|
92 |
vae.eval()
|
93 |
with torch.no_grad():
|
94 |
+
train_dataset.cache_latents(vae)
|
95 |
vae.to("cpu")
|
96 |
if torch.cuda.is_available():
|
97 |
torch.cuda.empty_cache()
|
|
|
115 |
|
116 |
# 学習に必要なクラスを準備する
|
117 |
print("prepare optimizer, data loader etc.")
|
118 |
+
|
119 |
+
# 8-bit Adamを使う
|
120 |
+
if args.use_8bit_adam:
|
121 |
+
try:
|
122 |
+
import bitsandbytes as bnb
|
123 |
+
except ImportError:
|
124 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
125 |
+
print("use 8-bit Adam optimizer")
|
126 |
+
optimizer_class = bnb.optim.AdamW8bit
|
127 |
+
elif args.use_lion_optimizer:
|
128 |
+
try:
|
129 |
+
import lion_pytorch
|
130 |
+
except ImportError:
|
131 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
132 |
+
print("use Lion optimizer")
|
133 |
+
optimizer_class = lion_pytorch.Lion
|
134 |
+
else:
|
135 |
+
optimizer_class = torch.optim.AdamW
|
136 |
+
|
137 |
if train_text_encoder:
|
138 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
139 |
else:
|
140 |
trainable_params = unet.parameters()
|
141 |
|
142 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
143 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
144 |
|
145 |
# dataloaderを準備する
|
146 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
147 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
148 |
train_dataloader = torch.utils.data.DataLoader(
|
149 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
150 |
|
151 |
# 学習ステップ数を計算する
|
152 |
if args.max_train_epochs is not None:
|
|
|
156 |
if args.stop_text_encoder_training is None:
|
157 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
158 |
|
159 |
+
# lr schedulerを用意する
|
160 |
+
lr_scheduler = diffusers.optimization.get_scheduler(
|
161 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
|
|
162 |
|
163 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
164 |
if args.full_fp16:
|
|
|
195 |
# 学習する
|
196 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
197 |
print("running training / 学習開始")
|
198 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
199 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
200 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
201 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
202 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
217 |
loss_total = 0.0
|
218 |
for epoch in range(num_train_epochs):
|
219 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
220 |
+
train_dataset.set_current_epoch(epoch + 1)
|
221 |
|
222 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
223 |
unet.train()
|
|
|
281 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
282 |
|
283 |
accelerator.backward(loss)
|
284 |
+
if accelerator.sync_gradients:
|
285 |
if train_text_encoder:
|
286 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
287 |
else:
|
288 |
params_to_clip = unet.parameters()
|
289 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
|
|
|
|
300 |
current_loss = loss.detach().item()
|
301 |
if args.logging_dir is not None:
|
302 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
303 |
accelerator.log(logs, step=global_step)
|
304 |
|
305 |
if epoch == 0:
|
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
|
|
|
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
355 |
|
356 |
parser.add_argument("--no_token_padding", action="store_true",
|
357 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
train_network.py
CHANGED
@@ -1,4 +1,8 @@
|
|
|
|
|
|
|
|
1 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
2 |
import importlib
|
3 |
import argparse
|
4 |
import gc
|
@@ -11,41 +15,94 @@ import json
|
|
11 |
from tqdm import tqdm
|
12 |
import torch
|
13 |
from accelerate.utils import set_seed
|
|
|
14 |
from diffusers import DDPMScheduler
|
15 |
|
16 |
import library.train_util as train_util
|
17 |
-
from library.train_util import
|
18 |
-
DreamBoothDataset,
|
19 |
-
)
|
20 |
-
import library.config_util as config_util
|
21 |
-
from library.config_util import (
|
22 |
-
ConfigSanitizer,
|
23 |
-
BlueprintGenerator,
|
24 |
-
)
|
25 |
|
26 |
|
27 |
def collate_fn(examples):
|
28 |
return examples[0]
|
29 |
|
30 |
|
31 |
-
# TODO 他のスクリプトと共通化する
|
32 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
33 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
34 |
|
35 |
if args.network_train_unet_only:
|
36 |
-
logs["lr/unet"] =
|
37 |
elif args.network_train_text_encoder_only:
|
38 |
-
logs["lr/textencoder"] =
|
39 |
else:
|
40 |
-
logs["lr/textencoder"] =
|
41 |
-
logs["lr/unet"] =
|
42 |
-
|
43 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
44 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
45 |
|
46 |
return logs
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def train(args):
|
50 |
session_id = random.randint(0, 2**32)
|
51 |
training_started_at = time.time()
|
@@ -54,7 +111,6 @@ def train(args):
|
|
54 |
|
55 |
cache_latents = args.cache_latents
|
56 |
use_dreambooth_method = args.in_json is None
|
57 |
-
use_user_config = args.dataset_config is not None
|
58 |
|
59 |
if args.seed is not None:
|
60 |
set_seed(args.seed)
|
@@ -62,51 +118,38 @@ def train(args):
|
|
62 |
tokenizer = train_util.load_tokenizer(args)
|
63 |
|
64 |
# データセットを準備する
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
else:
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
"image_dir": args.train_data_dir,
|
87 |
-
"metadata_file": args.in_json,
|
88 |
-
}]
|
89 |
-
}]
|
90 |
-
}
|
91 |
-
|
92 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
93 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
94 |
|
95 |
if args.debug_dataset:
|
96 |
-
train_util.debug_dataset(
|
97 |
return
|
98 |
-
if len(
|
99 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
100 |
return
|
101 |
|
102 |
-
if cache_latents:
|
103 |
-
assert train_dataset_group.is_latent_cacheable(
|
104 |
-
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
105 |
-
|
106 |
# acceleratorを準備する
|
107 |
print("prepare accelerator")
|
108 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
109 |
-
is_main_process = accelerator.is_main_process
|
110 |
|
111 |
# mixed precisionに対応した型を用意しておき適宜castする
|
112 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
@@ -118,7 +161,7 @@ def train(args):
|
|
118 |
if args.lowram:
|
119 |
text_encoder.to("cuda")
|
120 |
unet.to("cuda")
|
121 |
-
|
122 |
# モデルに xformers とか memory efficient attention を組み込む
|
123 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
124 |
|
@@ -128,15 +171,13 @@ def train(args):
|
|
128 |
vae.requires_grad_(False)
|
129 |
vae.eval()
|
130 |
with torch.no_grad():
|
131 |
-
|
132 |
vae.to("cpu")
|
133 |
if torch.cuda.is_available():
|
134 |
torch.cuda.empty_cache()
|
135 |
gc.collect()
|
136 |
|
137 |
# prepare network
|
138 |
-
import sys
|
139 |
-
sys.path.append(os.path.dirname(__file__))
|
140 |
print("import network module:", args.network_module)
|
141 |
network_module = importlib.import_module(args.network_module)
|
142 |
|
@@ -167,25 +208,48 @@ def train(args):
|
|
167 |
# 学習に必要なクラスを準備する
|
168 |
print("prepare optimizer, data loader etc.")
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
171 |
-
|
|
|
|
|
172 |
|
173 |
# dataloaderを準備する
|
174 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
175 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
176 |
train_dataloader = torch.utils.data.DataLoader(
|
177 |
-
|
178 |
|
179 |
# 学習ステップ数を計算する
|
180 |
if args.max_train_epochs is not None:
|
181 |
-
args.max_train_steps = args.max_train_epochs *
|
182 |
-
|
183 |
-
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
-
lr_scheduler =
|
187 |
-
|
188 |
-
|
|
|
|
|
189 |
|
190 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
191 |
if args.full_fp16:
|
@@ -253,21 +317,17 @@ def train(args):
|
|
253 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
254 |
|
255 |
# 学習する
|
256 |
-
# TODO: find a way to handle total batch size when there are multiple datasets
|
257 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
269 |
-
|
270 |
-
# TODO refactor metadata creation and move to util
|
271 |
metadata = {
|
272 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
273 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -275,10 +335,12 @@ def train(args):
|
|
275 |
"ss_learning_rate": args.learning_rate,
|
276 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
277 |
"ss_unet_lr": args.unet_lr,
|
278 |
-
"ss_num_train_images":
|
279 |
-
"ss_num_reg_images":
|
280 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
281 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
282 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
283 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
284 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -290,156 +352,33 @@ def train(args):
|
|
290 |
"ss_mixed_precision": args.mixed_precision,
|
291 |
"ss_full_fp16": bool(args.full_fp16),
|
292 |
"ss_v2": bool(args.v2),
|
|
|
293 |
"ss_clip_skip": args.clip_skip,
|
294 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
295 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
296 |
"ss_seed": args.seed,
|
297 |
-
"
|
298 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
299 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
300 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
301 |
-
"ss_optimizer": optimizer_name
|
302 |
-
"ss_max_grad_norm": args.max_grad_norm,
|
303 |
-
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
304 |
-
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
305 |
-
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
306 |
-
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
307 |
-
"ss_prior_loss_weight": args.prior_loss_weight,
|
308 |
}
|
309 |
|
310 |
-
if
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
datasets_metadata = []
|
315 |
-
tag_frequency = {} # merge tag frequency for metadata editor
|
316 |
-
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
317 |
-
|
318 |
-
for dataset in train_dataset_group.datasets:
|
319 |
-
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
320 |
-
dataset_metadata = {
|
321 |
-
"is_dreambooth": is_dreambooth_dataset,
|
322 |
-
"batch_size_per_device": dataset.batch_size,
|
323 |
-
"num_train_images": dataset.num_train_images, # includes repeating
|
324 |
-
"num_reg_images": dataset.num_reg_images,
|
325 |
-
"resolution": (dataset.width, dataset.height),
|
326 |
-
"enable_bucket": bool(dataset.enable_bucket),
|
327 |
-
"min_bucket_reso": dataset.min_bucket_reso,
|
328 |
-
"max_bucket_reso": dataset.max_bucket_reso,
|
329 |
-
"tag_frequency": dataset.tag_frequency,
|
330 |
-
"bucket_info": dataset.bucket_info,
|
331 |
-
}
|
332 |
-
|
333 |
-
subsets_metadata = []
|
334 |
-
for subset in dataset.subsets:
|
335 |
-
subset_metadata = {
|
336 |
-
"img_count": subset.img_count,
|
337 |
-
"num_repeats": subset.num_repeats,
|
338 |
-
"color_aug": bool(subset.color_aug),
|
339 |
-
"flip_aug": bool(subset.flip_aug),
|
340 |
-
"random_crop": bool(subset.random_crop),
|
341 |
-
"shuffle_caption": bool(subset.shuffle_caption),
|
342 |
-
"keep_tokens": subset.keep_tokens,
|
343 |
-
}
|
344 |
-
|
345 |
-
image_dir_or_metadata_file = None
|
346 |
-
if subset.image_dir:
|
347 |
-
image_dir = os.path.basename(subset.image_dir)
|
348 |
-
subset_metadata["image_dir"] = image_dir
|
349 |
-
image_dir_or_metadata_file = image_dir
|
350 |
-
|
351 |
-
if is_dreambooth_dataset:
|
352 |
-
subset_metadata["class_tokens"] = subset.class_tokens
|
353 |
-
subset_metadata["is_reg"] = subset.is_reg
|
354 |
-
if subset.is_reg:
|
355 |
-
image_dir_or_metadata_file = None # not merging reg dataset
|
356 |
-
else:
|
357 |
-
metadata_file = os.path.basename(subset.metadata_file)
|
358 |
-
subset_metadata["metadata_file"] = metadata_file
|
359 |
-
image_dir_or_metadata_file = metadata_file # may overwrite
|
360 |
-
|
361 |
-
subsets_metadata.append(subset_metadata)
|
362 |
-
|
363 |
-
# merge dataset dir: not reg subset only
|
364 |
-
# TODO update additional-network extension to show detailed dataset config from metadata
|
365 |
-
if image_dir_or_metadata_file is not None:
|
366 |
-
# datasets may have a certain dir multiple times
|
367 |
-
v = image_dir_or_metadata_file
|
368 |
-
i = 2
|
369 |
-
while v in dataset_dirs_info:
|
370 |
-
v = image_dir_or_metadata_file + f" ({i})"
|
371 |
-
i += 1
|
372 |
-
image_dir_or_metadata_file = v
|
373 |
-
|
374 |
-
dataset_dirs_info[image_dir_or_metadata_file] = {
|
375 |
-
"n_repeats": subset.num_repeats,
|
376 |
-
"img_count": subset.img_count
|
377 |
-
}
|
378 |
-
|
379 |
-
dataset_metadata["subsets"] = subsets_metadata
|
380 |
-
datasets_metadata.append(dataset_metadata)
|
381 |
-
|
382 |
-
# merge tag frequency:
|
383 |
-
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
384 |
-
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
385 |
-
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
386 |
-
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
387 |
-
if ds_dir_name in tag_frequency:
|
388 |
-
continue
|
389 |
-
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
390 |
-
|
391 |
-
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
392 |
-
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
393 |
-
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
394 |
-
else:
|
395 |
-
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
396 |
-
assert len(
|
397 |
-
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
398 |
-
|
399 |
-
dataset = train_dataset_group.datasets[0]
|
400 |
-
|
401 |
-
dataset_dirs_info = {}
|
402 |
-
reg_dataset_dirs_info = {}
|
403 |
-
if use_dreambooth_method:
|
404 |
-
for subset in dataset.subsets:
|
405 |
-
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
406 |
-
info[os.path.basename(subset.image_dir)] = {
|
407 |
-
"n_repeats": subset.num_repeats,
|
408 |
-
"img_count": subset.img_count
|
409 |
-
}
|
410 |
-
else:
|
411 |
-
for subset in dataset.subsets:
|
412 |
-
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
413 |
-
"n_repeats": subset.num_repeats,
|
414 |
-
"img_count": subset.img_count
|
415 |
-
}
|
416 |
-
|
417 |
-
metadata.update({
|
418 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
419 |
-
"ss_total_batch_size": total_batch_size,
|
420 |
-
"ss_resolution": args.resolution,
|
421 |
-
"ss_color_aug": bool(args.color_aug),
|
422 |
-
"ss_flip_aug": bool(args.flip_aug),
|
423 |
-
"ss_random_crop": bool(args.random_crop),
|
424 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
425 |
-
"ss_enable_bucket": bool(dataset.enable_bucket),
|
426 |
-
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
427 |
-
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
428 |
-
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
429 |
-
"ss_keep_tokens": args.keep_tokens,
|
430 |
-
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
431 |
-
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
432 |
-
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
433 |
-
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
434 |
-
})
|
435 |
-
|
436 |
-
# add extra args
|
437 |
-
if args.network_args:
|
438 |
-
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
439 |
-
# for key, value in net_kwargs.items():
|
440 |
-
# metadata["ss_arg_" + key] = value
|
441 |
-
|
442 |
-
# model name and hash
|
443 |
if args.pretrained_model_name_or_path is not None:
|
444 |
sd_model_name = args.pretrained_model_name_or_path
|
445 |
if os.path.exists(sd_model_name):
|
@@ -458,13 +397,6 @@ def train(args):
|
|
458 |
|
459 |
metadata = {k: str(v) for k, v in metadata.items()}
|
460 |
|
461 |
-
# make minimum metadata for filtering
|
462 |
-
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
463 |
-
minimum_metadata = {}
|
464 |
-
for key in minimum_keys:
|
465 |
-
if key in metadata:
|
466 |
-
minimum_metadata[key] = metadata[key]
|
467 |
-
|
468 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
469 |
global_step = 0
|
470 |
|
@@ -477,9 +409,8 @@ def train(args):
|
|
477 |
loss_list = []
|
478 |
loss_total = 0.0
|
479 |
for epoch in range(num_train_epochs):
|
480 |
-
|
481 |
-
|
482 |
-
train_dataset_group.set_current_epoch(epoch + 1)
|
483 |
|
484 |
metadata["ss_epoch"] = str(epoch+1)
|
485 |
|
@@ -516,7 +447,7 @@ def train(args):
|
|
516 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
517 |
|
518 |
# Predict the noise residual
|
519 |
-
with
|
520 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
521 |
|
522 |
if args.v_parameterization:
|
@@ -534,9 +465,9 @@ def train(args):
|
|
534 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
535 |
|
536 |
accelerator.backward(loss)
|
537 |
-
if accelerator.sync_gradients
|
538 |
params_to_clip = network.get_trainable_params()
|
539 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
540 |
|
541 |
optimizer.step()
|
542 |
lr_scheduler.step()
|
@@ -547,8 +478,6 @@ def train(args):
|
|
547 |
progress_bar.update(1)
|
548 |
global_step += 1
|
549 |
|
550 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
551 |
-
|
552 |
current_loss = loss.detach().item()
|
553 |
if epoch == 0:
|
554 |
loss_list.append(current_loss)
|
@@ -579,9 +508,8 @@ def train(args):
|
|
579 |
def save_func():
|
580 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
581 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
582 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
583 |
print(f"saving checkpoint: {ckpt_file}")
|
584 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
585 |
|
586 |
def remove_old_func(old_epoch_no):
|
587 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -590,18 +518,15 @@ def train(args):
|
|
590 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
591 |
os.remove(old_ckpt_file)
|
592 |
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
597 |
-
|
598 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
599 |
|
600 |
# end of epoch
|
601 |
|
602 |
metadata["ss_epoch"] = str(num_train_epochs)
|
603 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
604 |
|
|
|
605 |
if is_main_process:
|
606 |
network = unwrap_model(network)
|
607 |
|
@@ -620,7 +545,7 @@ def train(args):
|
|
620 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
621 |
|
622 |
print(f"save trained model to {ckpt_file}")
|
623 |
-
network.save_weights(ckpt_file, save_dtype,
|
624 |
print("model saved.")
|
625 |
|
626 |
|
@@ -630,8 +555,6 @@ if __name__ == '__main__':
|
|
630 |
train_util.add_sd_models_arguments(parser)
|
631 |
train_util.add_dataset_arguments(parser, True, True, True)
|
632 |
train_util.add_training_arguments(parser, True)
|
633 |
-
train_util.add_optimizer_arguments(parser)
|
634 |
-
config_util.add_config_arguments(parser)
|
635 |
|
636 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
637 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -639,6 +562,10 @@ if __name__ == '__main__':
|
|
639 |
|
640 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
641 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
642 |
|
643 |
parser.add_argument("--network_weights", type=str, default=None,
|
644 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
1 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
+
from torch.optim import Optimizer
|
3 |
+
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
+
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
+
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
|
21 |
import library.train_util as train_util
|
22 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
26 |
return examples[0]
|
27 |
|
28 |
|
|
|
29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
31 |
|
32 |
if args.network_train_unet_only:
|
33 |
+
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
34 |
elif args.network_train_text_encoder_only:
|
35 |
+
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
36 |
else:
|
37 |
+
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
38 |
+
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
|
|
|
|
|
|
39 |
|
40 |
return logs
|
41 |
|
42 |
|
43 |
+
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
44 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
45 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
46 |
+
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
47 |
+
|
48 |
+
|
49 |
+
def get_scheduler_fix(
|
50 |
+
name: Union[str, SchedulerType],
|
51 |
+
optimizer: Optimizer,
|
52 |
+
num_warmup_steps: Optional[int] = None,
|
53 |
+
num_training_steps: Optional[int] = None,
|
54 |
+
num_cycles: int = 1,
|
55 |
+
power: float = 1.0,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Unified API to get any scheduler from its name.
|
59 |
+
Args:
|
60 |
+
name (`str` or `SchedulerType`):
|
61 |
+
The name of the scheduler to use.
|
62 |
+
optimizer (`torch.optim.Optimizer`):
|
63 |
+
The optimizer that will be used during training.
|
64 |
+
num_warmup_steps (`int`, *optional*):
|
65 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
66 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
67 |
+
num_training_steps (`int``, *optional*):
|
68 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
69 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
70 |
+
num_cycles (`int`, *optional*):
|
71 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
72 |
+
power (`float`, *optional*, defaults to 1.0):
|
73 |
+
Power factor. See `POLYNOMIAL` scheduler
|
74 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
75 |
+
The index of the last epoch when resuming training.
|
76 |
+
"""
|
77 |
+
name = SchedulerType(name)
|
78 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
79 |
+
if name == SchedulerType.CONSTANT:
|
80 |
+
return schedule_func(optimizer)
|
81 |
+
|
82 |
+
# All other schedulers require `num_warmup_steps`
|
83 |
+
if num_warmup_steps is None:
|
84 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
85 |
+
|
86 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
87 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
88 |
+
|
89 |
+
# All other schedulers require `num_training_steps`
|
90 |
+
if num_training_steps is None:
|
91 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
92 |
+
|
93 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
94 |
+
return schedule_func(
|
95 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
96 |
+
)
|
97 |
+
|
98 |
+
if name == SchedulerType.POLYNOMIAL:
|
99 |
+
return schedule_func(
|
100 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
101 |
+
)
|
102 |
+
|
103 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
104 |
+
|
105 |
+
|
106 |
def train(args):
|
107 |
session_id = random.randint(0, 2**32)
|
108 |
training_started_at = time.time()
|
|
|
111 |
|
112 |
cache_latents = args.cache_latents
|
113 |
use_dreambooth_method = args.in_json is None
|
|
|
114 |
|
115 |
if args.seed is not None:
|
116 |
set_seed(args.seed)
|
|
|
118 |
tokenizer = train_util.load_tokenizer(args)
|
119 |
|
120 |
# データセットを準備する
|
121 |
+
if use_dreambooth_method:
|
122 |
+
print("Use DreamBooth method.")
|
123 |
+
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
124 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
125 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
126 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
127 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
128 |
+
args.random_crop, args.debug_dataset)
|
129 |
else:
|
130 |
+
print("Train with captions.")
|
131 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
132 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
133 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
134 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
135 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
136 |
+
args.dataset_repeats, args.debug_dataset)
|
137 |
+
|
138 |
+
# 学習データのdropout率を設定する
|
139 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
140 |
+
|
141 |
+
train_dataset.make_buckets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if args.debug_dataset:
|
144 |
+
train_util.debug_dataset(train_dataset)
|
145 |
return
|
146 |
+
if len(train_dataset) == 0:
|
147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# acceleratorを準備する
|
151 |
print("prepare accelerator")
|
152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
153 |
|
154 |
# mixed precisionに対応した型を用意しておき適宜castする
|
155 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
|
|
161 |
if args.lowram:
|
162 |
text_encoder.to("cuda")
|
163 |
unet.to("cuda")
|
164 |
+
|
165 |
# モデルに xformers とか memory efficient attention を組み込む
|
166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
167 |
|
|
|
171 |
vae.requires_grad_(False)
|
172 |
vae.eval()
|
173 |
with torch.no_grad():
|
174 |
+
train_dataset.cache_latents(vae)
|
175 |
vae.to("cpu")
|
176 |
if torch.cuda.is_available():
|
177 |
torch.cuda.empty_cache()
|
178 |
gc.collect()
|
179 |
|
180 |
# prepare network
|
|
|
|
|
181 |
print("import network module:", args.network_module)
|
182 |
network_module = importlib.import_module(args.network_module)
|
183 |
|
|
|
208 |
# 学習に必要なクラスを準備する
|
209 |
print("prepare optimizer, data loader etc.")
|
210 |
|
211 |
+
# 8-bit Adamを使う
|
212 |
+
if args.use_8bit_adam:
|
213 |
+
try:
|
214 |
+
import bitsandbytes as bnb
|
215 |
+
except ImportError:
|
216 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
217 |
+
print("use 8-bit Adam optimizer")
|
218 |
+
optimizer_class = bnb.optim.AdamW8bit
|
219 |
+
elif args.use_lion_optimizer:
|
220 |
+
try:
|
221 |
+
import lion_pytorch
|
222 |
+
except ImportError:
|
223 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
224 |
+
print("use Lion optimizer")
|
225 |
+
optimizer_class = lion_pytorch.Lion
|
226 |
+
else:
|
227 |
+
optimizer_class = torch.optim.AdamW
|
228 |
+
|
229 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
230 |
+
|
231 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
232 |
+
|
233 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
234 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
235 |
|
236 |
# dataloaderを準備する
|
237 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
238 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
239 |
train_dataloader = torch.utils.data.DataLoader(
|
240 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
241 |
|
242 |
# 学習ステップ数を計算する
|
243 |
if args.max_train_epochs is not None:
|
244 |
+
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
245 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
|
|
246 |
|
247 |
# lr schedulerを用意する
|
248 |
+
# lr_scheduler = diffusers.optimization.get_scheduler(
|
249 |
+
lr_scheduler = get_scheduler_fix(
|
250 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
251 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
252 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
253 |
|
254 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
255 |
if args.full_fp16:
|
|
|
317 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
318 |
|
319 |
# 学習する
|
|
|
320 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
321 |
+
print("running training / 学習開始")
|
322 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
323 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
324 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
325 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
326 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
327 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
328 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
329 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
330 |
+
|
|
|
|
|
|
|
331 |
metadata = {
|
332 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
333 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
335 |
"ss_learning_rate": args.learning_rate,
|
336 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
337 |
"ss_unet_lr": args.unet_lr,
|
338 |
+
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
339 |
+
"ss_num_reg_images": train_dataset.num_reg_images,
|
340 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
341 |
"ss_num_epochs": num_train_epochs,
|
342 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
343 |
+
"ss_total_batch_size": total_batch_size,
|
344 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
345 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
346 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
352 |
"ss_mixed_precision": args.mixed_precision,
|
353 |
"ss_full_fp16": bool(args.full_fp16),
|
354 |
"ss_v2": bool(args.v2),
|
355 |
+
"ss_resolution": args.resolution,
|
356 |
"ss_clip_skip": args.clip_skip,
|
357 |
"ss_max_token_length": args.max_token_length,
|
358 |
+
"ss_color_aug": bool(args.color_aug),
|
359 |
+
"ss_flip_aug": bool(args.flip_aug),
|
360 |
+
"ss_random_crop": bool(args.random_crop),
|
361 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
362 |
"ss_cache_latents": bool(args.cache_latents),
|
363 |
+
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
364 |
+
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
365 |
+
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
366 |
"ss_seed": args.seed,
|
367 |
+
"ss_keep_tokens": args.keep_tokens,
|
368 |
"ss_noise_offset": args.noise_offset,
|
369 |
+
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
370 |
+
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
371 |
+
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
372 |
+
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
373 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
374 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
375 |
+
"ss_optimizer": optimizer_name
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
}
|
377 |
|
378 |
+
# uncomment if another network is added
|
379 |
+
# for key, value in net_kwargs.items():
|
380 |
+
# metadata["ss_arg_" + key] = value
|
381 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
if args.pretrained_model_name_or_path is not None:
|
383 |
sd_model_name = args.pretrained_model_name_or_path
|
384 |
if os.path.exists(sd_model_name):
|
|
|
397 |
|
398 |
metadata = {k: str(v) for k, v in metadata.items()}
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
401 |
global_step = 0
|
402 |
|
|
|
409 |
loss_list = []
|
410 |
loss_total = 0.0
|
411 |
for epoch in range(num_train_epochs):
|
412 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
413 |
+
train_dataset.set_current_epoch(epoch + 1)
|
|
|
414 |
|
415 |
metadata["ss_epoch"] = str(epoch+1)
|
416 |
|
|
|
447 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
448 |
|
449 |
# Predict the noise residual
|
450 |
+
with autocast():
|
451 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
452 |
|
453 |
if args.v_parameterization:
|
|
|
465 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
466 |
|
467 |
accelerator.backward(loss)
|
468 |
+
if accelerator.sync_gradients:
|
469 |
params_to_clip = network.get_trainable_params()
|
470 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
471 |
|
472 |
optimizer.step()
|
473 |
lr_scheduler.step()
|
|
|
478 |
progress_bar.update(1)
|
479 |
global_step += 1
|
480 |
|
|
|
|
|
481 |
current_loss = loss.detach().item()
|
482 |
if epoch == 0:
|
483 |
loss_list.append(current_loss)
|
|
|
508 |
def save_func():
|
509 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
510 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
511 |
print(f"saving checkpoint: {ckpt_file}")
|
512 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
513 |
|
514 |
def remove_old_func(old_epoch_no):
|
515 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
518 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
519 |
os.remove(old_ckpt_file)
|
520 |
|
521 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
522 |
+
if saving and args.save_state:
|
523 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
|
|
|
|
|
|
524 |
|
525 |
# end of epoch
|
526 |
|
527 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
528 |
|
529 |
+
is_main_process = accelerator.is_main_process
|
530 |
if is_main_process:
|
531 |
network = unwrap_model(network)
|
532 |
|
|
|
545 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
546 |
|
547 |
print(f"save trained model to {ckpt_file}")
|
548 |
+
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
549 |
print("model saved.")
|
550 |
|
551 |
|
|
|
555 |
train_util.add_sd_models_arguments(parser)
|
556 |
train_util.add_dataset_arguments(parser, True, True, True)
|
557 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
558 |
|
559 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
560 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
562 |
|
563 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
564 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
565 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
566 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
567 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
568 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
569 |
|
570 |
parser.add_argument("--network_weights", type=str, default=None,
|
571 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
train_network_opt.py
CHANGED
@@ -1,5 +1,8 @@
|
|
|
|
|
|
1 |
from torch.cuda.amp import autocast
|
2 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
3 |
import importlib
|
4 |
import argparse
|
5 |
import gc
|
@@ -12,49 +15,138 @@ import json
|
|
12 |
from tqdm import tqdm
|
13 |
import torch
|
14 |
from accelerate.utils import set_seed
|
15 |
-
|
16 |
from diffusers import DDPMScheduler
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
##### バケット拡張のためのモジュール
|
19 |
import append_module
|
20 |
######
|
21 |
import library.train_util as train_util
|
22 |
-
from library.train_util import
|
23 |
-
DreamBoothDataset,
|
24 |
-
)
|
25 |
-
import library.config_util as config_util
|
26 |
-
from library.config_util import (
|
27 |
-
ConfigSanitizer,
|
28 |
-
BlueprintGenerator,
|
29 |
-
)
|
30 |
|
31 |
|
32 |
def collate_fn(examples):
|
33 |
return examples[0]
|
34 |
|
35 |
|
36 |
-
|
37 |
-
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
|
38 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
else:
|
45 |
-
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
46 |
-
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
47 |
else:
|
48 |
last_lrs = lr_scheduler.get_last_lr()
|
49 |
-
|
50 |
-
logs[
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
return logs
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def train(args):
|
59 |
session_id = random.randint(0, 2**32)
|
60 |
training_started_at = time.time()
|
@@ -63,7 +155,6 @@ def train(args):
|
|
63 |
|
64 |
cache_latents = args.cache_latents
|
65 |
use_dreambooth_method = args.in_json is None
|
66 |
-
use_user_config = args.dataset_config is not None
|
67 |
|
68 |
if args.seed is not None:
|
69 |
set_seed(args.seed)
|
@@ -71,72 +162,52 @@ def train(args):
|
|
71 |
tokenizer = train_util.load_tokenizer(args)
|
72 |
|
73 |
# データセットを準備する
|
74 |
-
if
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
88 |
-
else:
|
89 |
-
if use_dreambooth_method:
|
90 |
-
print("Use DreamBooth method.")
|
91 |
-
user_config = {
|
92 |
-
"datasets": [{
|
93 |
-
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
94 |
-
}]
|
95 |
-
}
|
96 |
-
else:
|
97 |
-
print("Train with captions.")
|
98 |
-
user_config = {
|
99 |
-
"datasets": [{
|
100 |
-
"subsets": [{
|
101 |
-
"image_dir": args.train_data_dir,
|
102 |
-
"metadata_file": args.in_json,
|
103 |
-
}]
|
104 |
-
}]
|
105 |
-
}
|
106 |
-
|
107 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
108 |
-
if args.min_resolution:
|
109 |
-
train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
110 |
else:
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
if args.debug_dataset:
|
114 |
-
train_util.debug_dataset(
|
115 |
return
|
116 |
-
if len(
|
117 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
118 |
return
|
119 |
|
120 |
-
if cache_latents:
|
121 |
-
assert train_dataset_group.is_latent_cacheable(
|
122 |
-
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
123 |
-
|
124 |
# acceleratorを準備する
|
125 |
print("prepare accelerator")
|
126 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
127 |
-
is_main_process = accelerator.is_main_process
|
128 |
|
129 |
# mixed precisionに対応した型を用意しておき適宜castする
|
130 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
131 |
|
132 |
# モデルを読み込む
|
133 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
text_encoder.to("cuda")
|
138 |
-
unet.to("cuda")
|
139 |
-
|
140 |
# モデルに xformers とか memory efficient attention を組み込む
|
141 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
142 |
|
@@ -146,15 +217,13 @@ def train(args):
|
|
146 |
vae.requires_grad_(False)
|
147 |
vae.eval()
|
148 |
with torch.no_grad():
|
149 |
-
|
150 |
vae.to("cpu")
|
151 |
if torch.cuda.is_available():
|
152 |
torch.cuda.empty_cache()
|
153 |
gc.collect()
|
154 |
|
155 |
# prepare network
|
156 |
-
import sys
|
157 |
-
sys.path.append(os.path.dirname(__file__))
|
158 |
print("import network module:", args.network_module)
|
159 |
network_module = importlib.import_module(args.network_module)
|
160 |
|
@@ -184,65 +253,188 @@ def train(args):
|
|
184 |
|
185 |
# 学習に必要なクラスを準備する
|
186 |
print("prepare optimizer, data loader etc.")
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
if args.split_lora_networks:
|
191 |
-
lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
|
192 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
193 |
-
append_module.replace_prepare_optimizer_params(network
|
194 |
-
trainable_params,
|
195 |
else:
|
196 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
if args.
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
if args.lookahead_arg is not None:
|
212 |
-
for _arg in args.lookahead_arg:
|
213 |
-
k, v = _arg.split("=")
|
214 |
-
if k == "k":
|
215 |
-
lookahed_arg[k] = int(v)
|
216 |
-
else:
|
217 |
-
lookahed_arg[k] = float(v)
|
218 |
-
optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
|
219 |
-
except:
|
220 |
-
print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
|
221 |
# dataloaderを準備する
|
222 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
223 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
224 |
train_dataloader = torch.utils.data.DataLoader(
|
225 |
-
|
226 |
|
227 |
# 学習ステップ数を計算する
|
228 |
if args.max_train_epochs is not None:
|
229 |
-
args.max_train_steps = args.max_train_epochs *
|
230 |
-
|
231 |
-
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
232 |
|
233 |
# lr schedulerを用意する
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
236 |
else:
|
237 |
-
lr_scheduler =
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
241 |
#追加機能の設定をコメントに追記して残す
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
|
246 |
if args.min_resolution:
|
247 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
248 |
|
@@ -312,21 +504,17 @@ def train(args):
|
|
312 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
313 |
|
314 |
# 学習する
|
315 |
-
# TODO: find a way to handle total batch size when there are multiple datasets
|
316 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
328 |
-
|
329 |
-
# TODO refactor metadata creation and move to util
|
330 |
metadata = {
|
331 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
332 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -334,10 +522,12 @@ def train(args):
|
|
334 |
"ss_learning_rate": args.learning_rate,
|
335 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
336 |
"ss_unet_lr": args.unet_lr,
|
337 |
-
"ss_num_train_images":
|
338 |
-
"ss_num_reg_images":
|
339 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
340 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
341 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
342 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
343 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -349,156 +539,32 @@ def train(args):
|
|
349 |
"ss_mixed_precision": args.mixed_precision,
|
350 |
"ss_full_fp16": bool(args.full_fp16),
|
351 |
"ss_v2": bool(args.v2),
|
|
|
352 |
"ss_clip_skip": args.clip_skip,
|
353 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
354 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
355 |
"ss_seed": args.seed,
|
356 |
-
"
|
357 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
358 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
359 |
-
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
360 |
-
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
361 |
-
"ss_max_grad_norm": args.max_grad_norm,
|
362 |
-
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
363 |
-
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
364 |
-
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
365 |
-
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
366 |
-
"ss_prior_loss_weight": args.prior_loss_weight,
|
367 |
}
|
368 |
|
369 |
-
if
|
370 |
-
# save metadata of multiple datasets
|
371 |
-
# NOTE: pack "ss_datasets" value as json one time
|
372 |
-
# or should also pack nested collections as json?
|
373 |
-
datasets_metadata = []
|
374 |
-
tag_frequency = {} # merge tag frequency for metadata editor
|
375 |
-
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
376 |
-
|
377 |
-
for dataset in train_dataset_group.datasets:
|
378 |
-
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
379 |
-
dataset_metadata = {
|
380 |
-
"is_dreambooth": is_dreambooth_dataset,
|
381 |
-
"batch_size_per_device": dataset.batch_size,
|
382 |
-
"num_train_images": dataset.num_train_images, # includes repeating
|
383 |
-
"num_reg_images": dataset.num_reg_images,
|
384 |
-
"resolution": (dataset.width, dataset.height),
|
385 |
-
"enable_bucket": bool(dataset.enable_bucket),
|
386 |
-
"min_bucket_reso": dataset.min_bucket_reso,
|
387 |
-
"max_bucket_reso": dataset.max_bucket_reso,
|
388 |
-
"tag_frequency": dataset.tag_frequency,
|
389 |
-
"bucket_info": dataset.bucket_info,
|
390 |
-
}
|
391 |
-
|
392 |
-
subsets_metadata = []
|
393 |
-
for subset in dataset.subsets:
|
394 |
-
subset_metadata = {
|
395 |
-
"img_count": subset.img_count,
|
396 |
-
"num_repeats": subset.num_repeats,
|
397 |
-
"color_aug": bool(subset.color_aug),
|
398 |
-
"flip_aug": bool(subset.flip_aug),
|
399 |
-
"random_crop": bool(subset.random_crop),
|
400 |
-
"shuffle_caption": bool(subset.shuffle_caption),
|
401 |
-
"keep_tokens": subset.keep_tokens,
|
402 |
-
}
|
403 |
-
|
404 |
-
image_dir_or_metadata_file = None
|
405 |
-
if subset.image_dir:
|
406 |
-
image_dir = os.path.basename(subset.image_dir)
|
407 |
-
subset_metadata["image_dir"] = image_dir
|
408 |
-
image_dir_or_metadata_file = image_dir
|
409 |
-
|
410 |
-
if is_dreambooth_dataset:
|
411 |
-
subset_metadata["class_tokens"] = subset.class_tokens
|
412 |
-
subset_metadata["is_reg"] = subset.is_reg
|
413 |
-
if subset.is_reg:
|
414 |
-
image_dir_or_metadata_file = None # not merging reg dataset
|
415 |
-
else:
|
416 |
-
metadata_file = os.path.basename(subset.metadata_file)
|
417 |
-
subset_metadata["metadata_file"] = metadata_file
|
418 |
-
image_dir_or_metadata_file = metadata_file # may overwrite
|
419 |
-
|
420 |
-
subsets_metadata.append(subset_metadata)
|
421 |
-
|
422 |
-
# merge dataset dir: not reg subset only
|
423 |
-
# TODO update additional-network extension to show detailed dataset config from metadata
|
424 |
-
if image_dir_or_metadata_file is not None:
|
425 |
-
# datasets may have a certain dir multiple times
|
426 |
-
v = image_dir_or_metadata_file
|
427 |
-
i = 2
|
428 |
-
while v in dataset_dirs_info:
|
429 |
-
v = image_dir_or_metadata_file + f" ({i})"
|
430 |
-
i += 1
|
431 |
-
image_dir_or_metadata_file = v
|
432 |
-
|
433 |
-
dataset_dirs_info[image_dir_or_metadata_file] = {
|
434 |
-
"n_repeats": subset.num_repeats,
|
435 |
-
"img_count": subset.img_count
|
436 |
-
}
|
437 |
-
|
438 |
-
dataset_metadata["subsets"] = subsets_metadata
|
439 |
-
datasets_metadata.append(dataset_metadata)
|
440 |
-
|
441 |
-
# merge tag frequency:
|
442 |
-
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
443 |
-
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
444 |
-
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
445 |
-
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
446 |
-
if ds_dir_name in tag_frequency:
|
447 |
-
continue
|
448 |
-
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
449 |
-
|
450 |
-
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
451 |
-
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
452 |
-
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
453 |
-
else:
|
454 |
-
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
455 |
-
assert len(
|
456 |
-
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
457 |
-
|
458 |
-
dataset = train_dataset_group.datasets[0]
|
459 |
-
|
460 |
-
dataset_dirs_info = {}
|
461 |
-
reg_dataset_dirs_info = {}
|
462 |
-
if use_dreambooth_method:
|
463 |
-
for subset in dataset.subsets:
|
464 |
-
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
465 |
-
info[os.path.basename(subset.image_dir)] = {
|
466 |
-
"n_repeats": subset.num_repeats,
|
467 |
-
"img_count": subset.img_count
|
468 |
-
}
|
469 |
-
else:
|
470 |
-
for subset in dataset.subsets:
|
471 |
-
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
472 |
-
"n_repeats": subset.num_repeats,
|
473 |
-
"img_count": subset.img_count
|
474 |
-
}
|
475 |
-
|
476 |
-
metadata.update({
|
477 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
478 |
-
"ss_total_batch_size": total_batch_size,
|
479 |
-
"ss_resolution": args.resolution,
|
480 |
-
"ss_color_aug": bool(args.color_aug),
|
481 |
-
"ss_flip_aug": bool(args.flip_aug),
|
482 |
-
"ss_random_crop": bool(args.random_crop),
|
483 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
484 |
-
"ss_enable_bucket": bool(dataset.enable_bucket),
|
485 |
-
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
486 |
-
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
487 |
-
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
488 |
-
"ss_keep_tokens": args.keep_tokens,
|
489 |
-
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
490 |
-
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
491 |
-
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
492 |
-
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
493 |
-
})
|
494 |
-
|
495 |
-
# add extra args
|
496 |
-
if args.network_args:
|
497 |
-
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
498 |
# for key, value in net_kwargs.items():
|
499 |
# metadata["ss_arg_" + key] = value
|
500 |
|
501 |
-
# model name and hash
|
502 |
if args.pretrained_model_name_or_path is not None:
|
503 |
sd_model_name = args.pretrained_model_name_or_path
|
504 |
if os.path.exists(sd_model_name):
|
@@ -517,13 +583,6 @@ def train(args):
|
|
517 |
|
518 |
metadata = {k: str(v) for k, v in metadata.items()}
|
519 |
|
520 |
-
# make minimum metadata for filtering
|
521 |
-
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
522 |
-
minimum_metadata = {}
|
523 |
-
for key in minimum_keys:
|
524 |
-
if key in metadata:
|
525 |
-
minimum_metadata[key] = metadata[key]
|
526 |
-
|
527 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
528 |
global_step = 0
|
529 |
|
@@ -536,9 +595,8 @@ def train(args):
|
|
536 |
loss_list = []
|
537 |
loss_total = 0.0
|
538 |
for epoch in range(num_train_epochs):
|
539 |
-
|
540 |
-
|
541 |
-
train_dataset_group.set_current_epoch(epoch + 1)
|
542 |
|
543 |
metadata["ss_epoch"] = str(epoch+1)
|
544 |
|
@@ -575,7 +633,7 @@ def train(args):
|
|
575 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
576 |
|
577 |
# Predict the noise residual
|
578 |
-
with
|
579 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
580 |
|
581 |
if args.v_parameterization:
|
@@ -593,13 +651,12 @@ def train(args):
|
|
593 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
594 |
|
595 |
accelerator.backward(loss)
|
596 |
-
if accelerator.sync_gradients
|
597 |
params_to_clip = network.get_trainable_params()
|
598 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
599 |
|
600 |
optimizer.step()
|
601 |
-
|
602 |
-
lr_scheduler.step()
|
603 |
optimizer.zero_grad(set_to_none=True)
|
604 |
|
605 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
@@ -607,8 +664,6 @@ def train(args):
|
|
607 |
progress_bar.update(1)
|
608 |
global_step += 1
|
609 |
|
610 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
611 |
-
|
612 |
current_loss = loss.detach().item()
|
613 |
if epoch == 0:
|
614 |
loss_list.append(current_loss)
|
@@ -621,7 +676,7 @@ def train(args):
|
|
621 |
progress_bar.set_postfix(**logs)
|
622 |
|
623 |
if args.logging_dir is not None:
|
624 |
-
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler
|
625 |
accelerator.log(logs, step=global_step)
|
626 |
|
627 |
if global_step >= args.max_train_steps:
|
@@ -639,9 +694,8 @@ def train(args):
|
|
639 |
def save_func():
|
640 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
641 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
642 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
643 |
print(f"saving checkpoint: {ckpt_file}")
|
644 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
645 |
|
646 |
def remove_old_func(old_epoch_no):
|
647 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -650,18 +704,15 @@ def train(args):
|
|
650 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
651 |
os.remove(old_ckpt_file)
|
652 |
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
657 |
-
|
658 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
659 |
|
660 |
# end of epoch
|
661 |
|
662 |
metadata["ss_epoch"] = str(num_train_epochs)
|
663 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
664 |
|
|
|
665 |
if is_main_process:
|
666 |
network = unwrap_model(network)
|
667 |
|
@@ -680,7 +731,7 @@ def train(args):
|
|
680 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
681 |
|
682 |
print(f"save trained model to {ckpt_file}")
|
683 |
-
network.save_weights(ckpt_file, save_dtype,
|
684 |
print("model saved.")
|
685 |
|
686 |
|
@@ -690,8 +741,6 @@ if __name__ == '__main__':
|
|
690 |
train_util.add_sd_models_arguments(parser)
|
691 |
train_util.add_dataset_arguments(parser, True, True, True)
|
692 |
train_util.add_training_arguments(parser, True)
|
693 |
-
train_util.add_optimizer_arguments(parser)
|
694 |
-
config_util.add_config_arguments(parser)
|
695 |
|
696 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
697 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -699,6 +748,10 @@ if __name__ == '__main__':
|
|
699 |
|
700 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
701 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
702 |
|
703 |
parser.add_argument("--network_weights", type=str, default=None,
|
704 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
@@ -718,30 +771,27 @@ if __name__ == '__main__':
|
|
718 |
#Optimizer変更関連のオプション追加
|
719 |
append_module.add_append_arguments(parser)
|
720 |
args = append_module.get_config(parser)
|
721 |
-
if not args.not_output_config:
|
722 |
-
#argsを保存する
|
723 |
-
import yaml
|
724 |
-
import datetime
|
725 |
-
_t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
726 |
-
if args.output_name==None:
|
727 |
-
config_name = f"train_network_config_{_t}.yaml"
|
728 |
-
else:
|
729 |
-
config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
730 |
-
print(f"{config_name} に設定を書き出し中...")
|
731 |
-
with open(config_name, mode="w") as f:
|
732 |
-
yaml.dump(args.__dict__, f, indent=4)
|
733 |
|
734 |
if args.resolution==args.min_resolution:
|
735 |
args.min_resolution=None
|
736 |
|
737 |
train(args)
|
738 |
-
print("done!")
|
739 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
740 |
|
741 |
'''
|
742 |
optimizer設定メモ
|
743 |
-
torch_optimizer.AdaBelief
|
744 |
-
adastand.Adastand
|
745 |
(optimizer_argから設定できるように変更するためのメモ)
|
746 |
|
747 |
AdamWのweight_decay初期値は1e-2
|
@@ -771,7 +821,6 @@ Adafactor
|
|
771 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
772 |
huggingfaceのサンプルパラ
|
773 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
774 |
-
epsの二つ目の値1e-3が学習率に影響大きい
|
775 |
|
776 |
AggMo
|
777 |
|
|
|
1 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
+
from torch.optim import Optimizer
|
3 |
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
+
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
+
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
+
print("**********************************")
|
21 |
+
#先に
|
22 |
+
#pip install torch_optimizer
|
23 |
+
#が必要
|
24 |
+
try:
|
25 |
+
import torch_optimizer as optim
|
26 |
+
except:
|
27 |
+
print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
|
28 |
+
try:
|
29 |
+
import adastand
|
30 |
+
except:
|
31 |
+
print("※Adastandが使えません")
|
32 |
+
|
33 |
+
from transformers.optimization import Adafactor, AdafactorSchedule
|
34 |
+
print("**********************************")
|
35 |
##### バケット拡張のためのモジュール
|
36 |
import append_module
|
37 |
######
|
38 |
import library.train_util as train_util
|
39 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def collate_fn(examples):
|
43 |
return examples[0]
|
44 |
|
45 |
|
46 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
|
|
47 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
48 |
+
|
49 |
+
if args.network_train_unet_only:
|
50 |
+
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
51 |
+
elif args.network_train_text_encoder_only:
|
52 |
+
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
|
|
|
|
|
|
53 |
else:
|
54 |
last_lrs = lr_scheduler.get_last_lr()
|
55 |
+
if len(last_lrs) == 2:
|
56 |
+
logs["lr/textencoder"] = float(last_lrs[0])
|
57 |
+
logs["lr/unet"] = float(last_lrs[-1]) # may be same to textencoder
|
58 |
+
else:
|
59 |
+
if len(last_lrs) == 4:
|
60 |
+
logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
|
61 |
+
elif len(last_lrs) == 8:
|
62 |
+
logs_names = ["textencoder", "unet_midblock"]
|
63 |
+
for i in range(3):
|
64 |
+
logs_names.append(f"unet_down_blocks_{i}")
|
65 |
+
logs_names.append(f"unet_up_blocks_{i+1}")
|
66 |
+
else:
|
67 |
+
logs_names = []
|
68 |
+
for i in range(12):
|
69 |
+
logs_names.append(f"text_model_encoder_layers_{i}_")
|
70 |
+
logs_names.append("unet_midblock")
|
71 |
+
for i in range(3):
|
72 |
+
logs_names.append(f"unet_down_blocks_{i}")
|
73 |
+
logs_names.append(f"unet_up_blocks_{i+1}")
|
74 |
+
|
75 |
+
for last_lr, logs_name in zip(last_lrs, logs_names):
|
76 |
+
logs[f"lr/{logs_name}"] = float(last_lr)
|
77 |
|
78 |
return logs
|
79 |
|
80 |
|
81 |
+
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
82 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
83 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
84 |
+
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
85 |
+
|
86 |
+
|
87 |
+
def get_scheduler_fix(
|
88 |
+
name: Union[str, SchedulerType],
|
89 |
+
optimizer: Optimizer,
|
90 |
+
num_warmup_steps: Optional[int] = None,
|
91 |
+
num_training_steps: Optional[int] = None,
|
92 |
+
num_cycles: float = 1.,
|
93 |
+
power: float = 1.0,
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
Unified API to get any scheduler from its name.
|
97 |
+
Args:
|
98 |
+
name (`str` or `SchedulerType`):
|
99 |
+
The name of the scheduler to use.
|
100 |
+
optimizer (`torch.optim.Optimizer`):
|
101 |
+
The optimizer that will be used during training.
|
102 |
+
num_warmup_steps (`int`, *optional*):
|
103 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
104 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
105 |
+
num_training_steps (`int``, *optional*):
|
106 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
107 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
108 |
+
num_cycles (`int`, *optional*):
|
109 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
110 |
+
power (`float`, *optional*, defaults to 1.0):
|
111 |
+
Power factor. See `POLYNOMIAL` scheduler
|
112 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
113 |
+
The index of the last epoch when resuming training.
|
114 |
+
"""
|
115 |
+
name = SchedulerType(name)
|
116 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
117 |
+
if name == SchedulerType.CONSTANT:
|
118 |
+
return schedule_func(optimizer)
|
119 |
+
|
120 |
+
# All other schedulers require `num_warmup_steps`
|
121 |
+
if num_warmup_steps is None:
|
122 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
123 |
+
|
124 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
125 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
126 |
+
|
127 |
+
# All other schedulers require `num_training_steps`
|
128 |
+
if num_training_steps is None:
|
129 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
130 |
+
|
131 |
+
if name == SchedulerType.COSINE:
|
132 |
+
print(f"{name} num_cycles: {num_cycles}")
|
133 |
+
return schedule_func(
|
134 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
135 |
+
)
|
136 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
137 |
+
print(f"{name} num_cycles: {int(num_cycles)}")
|
138 |
+
return schedule_func(
|
139 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
|
140 |
+
)
|
141 |
+
|
142 |
+
if name == SchedulerType.POLYNOMIAL:
|
143 |
+
return schedule_func(
|
144 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
145 |
+
)
|
146 |
+
|
147 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
148 |
+
|
149 |
+
|
150 |
def train(args):
|
151 |
session_id = random.randint(0, 2**32)
|
152 |
training_started_at = time.time()
|
|
|
155 |
|
156 |
cache_latents = args.cache_latents
|
157 |
use_dreambooth_method = args.in_json is None
|
|
|
158 |
|
159 |
if args.seed is not None:
|
160 |
set_seed(args.seed)
|
|
|
162 |
tokenizer = train_util.load_tokenizer(args)
|
163 |
|
164 |
# データセットを準備する
|
165 |
+
if use_dreambooth_method:
|
166 |
+
if args.min_resolution:
|
167 |
+
args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
|
168 |
+
if len(args.min_resolution) == 1:
|
169 |
+
args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
|
170 |
+
|
171 |
+
print("Use DreamBooth method.")
|
172 |
+
train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
173 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
174 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
175 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
176 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
177 |
+
args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
else:
|
179 |
+
print("Train with captions.")
|
180 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
181 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
182 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
183 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
184 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
185 |
+
args.dataset_repeats, args.debug_dataset)
|
186 |
+
|
187 |
+
# 学習データのdropout率を設定する
|
188 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
189 |
+
|
190 |
+
train_dataset.make_buckets()
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
+
train_util.debug_dataset(train_dataset)
|
194 |
return
|
195 |
+
if len(train_dataset) == 0:
|
196 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
197 |
return
|
198 |
|
|
|
|
|
|
|
|
|
199 |
# acceleratorを準備する
|
200 |
print("prepare accelerator")
|
201 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
202 |
|
203 |
# mixed precisionに対応した型を用意しておき適宜castする
|
204 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
205 |
|
206 |
# モデルを読み込む
|
207 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
208 |
+
# unnecessary, but work on low-ram device
|
209 |
+
text_encoder.to("cuda")
|
210 |
+
unet.to("cuda")
|
|
|
|
|
|
|
211 |
# モデルに xformers とか memory efficient attention を組み込む
|
212 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
213 |
|
|
|
217 |
vae.requires_grad_(False)
|
218 |
vae.eval()
|
219 |
with torch.no_grad():
|
220 |
+
train_dataset.cache_latents(vae)
|
221 |
vae.to("cpu")
|
222 |
if torch.cuda.is_available():
|
223 |
torch.cuda.empty_cache()
|
224 |
gc.collect()
|
225 |
|
226 |
# prepare network
|
|
|
|
|
227 |
print("import network module:", args.network_module)
|
228 |
network_module = importlib.import_module(args.network_module)
|
229 |
|
|
|
253 |
|
254 |
# 学習に必要なクラスを準備する
|
255 |
print("prepare optimizer, data loader etc.")
|
256 |
+
try:
|
257 |
+
print(f"torch_optimzier version is {optim.__version__}")
|
258 |
+
not_torch_optimizer_flag = False
|
259 |
+
except:
|
260 |
+
not_torch_optimizer_flag = True
|
261 |
+
try:
|
262 |
+
print(f"adastand version is {adastand.__version__()}")
|
263 |
+
not_adasatand_optimzier_flag = False
|
264 |
+
except:
|
265 |
+
not_adasatand_optimzier_flag = True
|
266 |
+
|
267 |
+
# 8-bit Adamを使う
|
268 |
+
if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
|
269 |
+
not_torch_optimizer_flag = False
|
270 |
+
if args.optimizer=="Adafactor":
|
271 |
+
not_adasatand_optimzier_flag = False
|
272 |
+
if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
|
273 |
+
print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
|
274 |
+
args.optimizer="AdamW"
|
275 |
+
if args.use_8bit_adam:
|
276 |
+
if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
|
277 |
+
print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
|
278 |
+
args.use_8bit_adam=False
|
279 |
+
if args.use_8bit_adam:
|
280 |
+
try:
|
281 |
+
import bitsandbytes as bnb
|
282 |
+
except ImportError:
|
283 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
284 |
+
print("use 8-bit Adam optimizer")
|
285 |
+
args.training_comment=f"{args.training_comment} use_8bit_adam=True"
|
286 |
+
if args.optimizer=="Lamb":
|
287 |
+
optimizer_class = bnb.optim.LAMB8bit
|
288 |
+
else:
|
289 |
+
args.optimizer="AdamW"
|
290 |
+
optimizer_class = bnb.optim.AdamW8bit
|
291 |
+
else:
|
292 |
+
print(f"use {args.optimizer}")
|
293 |
+
if args.optimizer=="RAdam":
|
294 |
+
optimizer_class = torch.optim.RAdam
|
295 |
+
elif args.optimizer=="AdaBound":
|
296 |
+
optimizer_class = optim.AdaBound
|
297 |
+
elif args.optimizer=="AdaBelief":
|
298 |
+
optimizer_class = optim.AdaBelief
|
299 |
+
elif args.optimizer=="AdamP":
|
300 |
+
optimizer_class = optim.AdamP
|
301 |
+
elif args.optimizer=="Adafactor":
|
302 |
+
optimizer_class = Adafactor
|
303 |
+
elif args.optimizer=="Adastand":
|
304 |
+
optimizer_class = adastand.Adastand
|
305 |
+
elif args.optimizer=="Adastand_belief":
|
306 |
+
optimizer_class = adastand.Adastand_b
|
307 |
+
elif args.optimizer=="AggMo":
|
308 |
+
optimizer_class = optim.AggMo
|
309 |
+
elif args.optimizer=="Apollo":
|
310 |
+
optimizer_class = optim.Apollo
|
311 |
+
elif args.optimizer=="Lamb":
|
312 |
+
optimizer_class = optim.Lamb
|
313 |
+
elif args.optimizer=="Ranger":
|
314 |
+
optimizer_class = optim.Ranger
|
315 |
+
elif args.optimizer=="RangerVA":
|
316 |
+
optimizer_class = optim.RangerVA
|
317 |
+
elif args.optimizer=="Yogi":
|
318 |
+
optimizer_class = optim.Yogi
|
319 |
+
elif args.optimizer=="Shampoo":
|
320 |
+
optimizer_class = optim.Shampoo
|
321 |
+
elif args.optimizer=="NovoGrad":
|
322 |
+
optimizer_class = optim.NovoGrad
|
323 |
+
elif args.optimizer=="QHAdam":
|
324 |
+
optimizer_class = optim.QHAdam
|
325 |
+
elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
|
326 |
+
optimizer_class = optim.DiffGrad
|
327 |
+
elif args.optimizer=="MADGRAD":
|
328 |
+
optimizer_class = optim.MADGRAD
|
329 |
+
else:
|
330 |
+
optimizer_class = torch.optim.AdamW
|
331 |
+
|
332 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
333 |
+
#optimizerデフォ設定
|
334 |
+
if args.optimizer_arg==None:
|
335 |
+
if args.optimizer=="AdaBelief":
|
336 |
+
args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
|
337 |
+
elif args.optimizer=="DiffGrad":
|
338 |
+
args.optimizer_arg = ["eps=1e-16"]
|
339 |
+
optimizer_arg = {}
|
340 |
+
lookahed_arg = {"k": 5, "alpha": 0.5}
|
341 |
+
adafactor_scheduler_arg = {"initial_lr": 0.}
|
342 |
+
int_args = ["k","n_sma_threshold","warmup"]
|
343 |
+
str_args = ["transformer","grad_transformer"]
|
344 |
+
if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
|
345 |
+
for _opt_arg in args.optimizer_arg:
|
346 |
+
key, value = _opt_arg.split("=")
|
347 |
+
if value=="True" or value=="False":
|
348 |
+
optimizer_arg[key]=bool((value=="True"))
|
349 |
+
elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
|
350 |
+
_value = value.split(",")
|
351 |
+
optimizer_arg[key] = (float(_value[0]),float(_value[1]))
|
352 |
+
del _value
|
353 |
+
elif key in int_args:
|
354 |
+
if "Lookahead" in args.optimizer:
|
355 |
+
lookahed_arg[key] = int(value)
|
356 |
+
else:
|
357 |
+
optimizer_arg[key] = int(value)
|
358 |
+
elif key in str_args:
|
359 |
+
optimizer_arg[key] = value
|
360 |
+
else:
|
361 |
+
if key=="alpha" and "Lookahead" in args.optimizer:
|
362 |
+
lookahed_arg[key] = int(value)
|
363 |
+
elif key=="initial_lr" and args.optimizer == "Adafactor":
|
364 |
+
adafactor_scheduler_arg[key] = float(value)
|
365 |
+
else:
|
366 |
+
optimizer_arg[key] = float(value)
|
367 |
+
del _opt_arg
|
368 |
+
AdafactorScheduler_Flag = False
|
369 |
+
list_of_init_lr = []
|
370 |
+
if args.optimizer=="Adafactor":
|
371 |
+
if not "relative_step" in optimizer_arg:
|
372 |
+
optimizer_arg["relative_step"] = True
|
373 |
+
if "warmup_init" in optimizer_arg:
|
374 |
+
if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
|
375 |
+
print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
|
376 |
+
optimizer_arg["relative_step"] = True
|
377 |
+
if optimizer_arg["relative_step"] == True:
|
378 |
+
AdafactorScheduler_Flag = True
|
379 |
+
list_of_init_lr = [0.,0.]
|
380 |
+
if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
|
381 |
+
if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
|
382 |
+
#if not "initial_lr" in adafactor_scheduler_arg:
|
383 |
+
# adafactor_scheduler_arg = args.learning_rate
|
384 |
+
args.learning_rate = None
|
385 |
+
args.text_encoder_lr = None
|
386 |
+
args.unet_lr = None
|
387 |
+
print(f"optimizer arg: {optimizer_arg}")
|
388 |
+
print("=-----------------------------------=")
|
389 |
+
if not AdafactorScheduler_Flag: args.split_lora_networks = False
|
390 |
if args.split_lora_networks:
|
|
|
391 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
392 |
+
append_module.replace_prepare_optimizer_params(network)
|
393 |
+
trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
|
394 |
else:
|
395 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
396 |
+
_list_of_init_lr = []
|
397 |
+
print(f"trainable_params_len: {len(trainable_params)}")
|
398 |
+
if len(_list_of_init_lr)>0:
|
399 |
+
list_of_init_lr = _list_of_init_lr
|
400 |
+
print(f"split loras network is {len(list_of_init_lr)}")
|
401 |
+
if len(list_of_init_lr) > 0:
|
402 |
+
adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
|
403 |
+
|
404 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
|
405 |
+
|
406 |
+
if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
|
407 |
+
optimizer = optim.Lookahead(optimizer, **lookahed_arg)
|
408 |
+
print(f"lookahed_arg: {lookahed_arg}")
|
409 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
# dataloaderを準備する
|
411 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
412 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
413 |
train_dataloader = torch.utils.data.DataLoader(
|
414 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
415 |
|
416 |
# 学習ステップ数を計算する
|
417 |
if args.max_train_epochs is not None:
|
418 |
+
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
419 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
|
|
420 |
|
421 |
# lr schedulerを用意する
|
422 |
+
# lr_scheduler = diffusers.optimization.get_scheduler(
|
423 |
+
if AdafactorScheduler_Flag:
|
424 |
+
print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
|
425 |
+
lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
|
426 |
+
print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
|
427 |
+
del list_of_init_lr
|
428 |
else:
|
429 |
+
lr_scheduler = get_scheduler_fix(
|
430 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
431 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
432 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
433 |
+
|
434 |
#追加機能の設定をコメントに追記して残す
|
435 |
+
args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
|
436 |
+
if AdafactorScheduler_Flag:
|
437 |
+
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
|
|
|
438 |
if args.min_resolution:
|
439 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
440 |
|
|
|
504 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
505 |
|
506 |
# 学習する
|
|
|
507 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
508 |
+
print("running training / 学習開始")
|
509 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
510 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
511 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
512 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
513 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
514 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
515 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
516 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
517 |
+
|
|
|
|
|
|
|
518 |
metadata = {
|
519 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
520 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
522 |
"ss_learning_rate": args.learning_rate,
|
523 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
524 |
"ss_unet_lr": args.unet_lr,
|
525 |
+
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
526 |
+
"ss_num_reg_images": train_dataset.num_reg_images,
|
527 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
528 |
"ss_num_epochs": num_train_epochs,
|
529 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
530 |
+
"ss_total_batch_size": total_batch_size,
|
531 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
532 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
533 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
539 |
"ss_mixed_precision": args.mixed_precision,
|
540 |
"ss_full_fp16": bool(args.full_fp16),
|
541 |
"ss_v2": bool(args.v2),
|
542 |
+
"ss_resolution": args.resolution,
|
543 |
"ss_clip_skip": args.clip_skip,
|
544 |
"ss_max_token_length": args.max_token_length,
|
545 |
+
"ss_color_aug": bool(args.color_aug),
|
546 |
+
"ss_flip_aug": bool(args.flip_aug),
|
547 |
+
"ss_random_crop": bool(args.random_crop),
|
548 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
549 |
"ss_cache_latents": bool(args.cache_latents),
|
550 |
+
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
551 |
+
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
552 |
+
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
553 |
"ss_seed": args.seed,
|
554 |
+
"ss_keep_tokens": args.keep_tokens,
|
555 |
"ss_noise_offset": args.noise_offset,
|
556 |
+
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
557 |
+
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
558 |
+
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
559 |
+
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
560 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
561 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
}
|
563 |
|
564 |
+
# uncomment if another network is added
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
# for key, value in net_kwargs.items():
|
566 |
# metadata["ss_arg_" + key] = value
|
567 |
|
|
|
568 |
if args.pretrained_model_name_or_path is not None:
|
569 |
sd_model_name = args.pretrained_model_name_or_path
|
570 |
if os.path.exists(sd_model_name):
|
|
|
583 |
|
584 |
metadata = {k: str(v) for k, v in metadata.items()}
|
585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
587 |
global_step = 0
|
588 |
|
|
|
595 |
loss_list = []
|
596 |
loss_total = 0.0
|
597 |
for epoch in range(num_train_epochs):
|
598 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
599 |
+
train_dataset.set_current_epoch(epoch + 1)
|
|
|
600 |
|
601 |
metadata["ss_epoch"] = str(epoch+1)
|
602 |
|
|
|
633 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
634 |
|
635 |
# Predict the noise residual
|
636 |
+
with autocast():
|
637 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
638 |
|
639 |
if args.v_parameterization:
|
|
|
651 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
652 |
|
653 |
accelerator.backward(loss)
|
654 |
+
if accelerator.sync_gradients:
|
655 |
params_to_clip = network.get_trainable_params()
|
656 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
657 |
|
658 |
optimizer.step()
|
659 |
+
lr_scheduler.step()
|
|
|
660 |
optimizer.zero_grad(set_to_none=True)
|
661 |
|
662 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
|
664 |
progress_bar.update(1)
|
665 |
global_step += 1
|
666 |
|
|
|
|
|
667 |
current_loss = loss.detach().item()
|
668 |
if epoch == 0:
|
669 |
loss_list.append(current_loss)
|
|
|
676 |
progress_bar.set_postfix(**logs)
|
677 |
|
678 |
if args.logging_dir is not None:
|
679 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
680 |
accelerator.log(logs, step=global_step)
|
681 |
|
682 |
if global_step >= args.max_train_steps:
|
|
|
694 |
def save_func():
|
695 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
696 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
697 |
print(f"saving checkpoint: {ckpt_file}")
|
698 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
699 |
|
700 |
def remove_old_func(old_epoch_no):
|
701 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
704 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
705 |
os.remove(old_ckpt_file)
|
706 |
|
707 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
708 |
+
if saving and args.save_state:
|
709 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
|
|
|
|
|
|
710 |
|
711 |
# end of epoch
|
712 |
|
713 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
714 |
|
715 |
+
is_main_process = accelerator.is_main_process
|
716 |
if is_main_process:
|
717 |
network = unwrap_model(network)
|
718 |
|
|
|
731 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
732 |
|
733 |
print(f"save trained model to {ckpt_file}")
|
734 |
+
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
735 |
print("model saved.")
|
736 |
|
737 |
|
|
|
741 |
train_util.add_sd_models_arguments(parser)
|
742 |
train_util.add_dataset_arguments(parser, True, True, True)
|
743 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
744 |
|
745 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
746 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
748 |
|
749 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
750 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
751 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
752 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
753 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
754 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
755 |
|
756 |
parser.add_argument("--network_weights", type=str, default=None,
|
757 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
771 |
#Optimizer変更関連のオプション追加
|
772 |
append_module.add_append_arguments(parser)
|
773 |
args = append_module.get_config(parser)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
774 |
|
775 |
if args.resolution==args.min_resolution:
|
776 |
args.min_resolution=None
|
777 |
|
778 |
train(args)
|
|
|
779 |
|
780 |
+
#学習が終わったら現在のargsを保存する
|
781 |
+
# import yaml
|
782 |
+
# import datetime
|
783 |
+
# _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
784 |
+
# if args.output_name==None:
|
785 |
+
# config_name = f"train_network_config_{_t}.yaml"
|
786 |
+
# else:
|
787 |
+
# config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
788 |
+
# print(f"{config_name} に設定を書き出し中...")
|
789 |
+
# with open(config_name, mode="w") as f:
|
790 |
+
# yaml.dump(args.__dict__, f, indent=4)
|
791 |
+
# print("done!")
|
792 |
|
793 |
'''
|
794 |
optimizer設定メモ
|
|
|
|
|
795 |
(optimizer_argから設定できるように変更するためのメモ)
|
796 |
|
797 |
AdamWのweight_decay初期値は1e-2
|
|
|
821 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
822 |
huggingfaceのサンプルパラ
|
823 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
|
|
824 |
|
825 |
AggMo
|
826 |
|
train_textual_inversion.py
CHANGED
@@ -11,11 +11,7 @@ import diffusers
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
-
|
15 |
-
from library.config_util import (
|
16 |
-
ConfigSanitizer,
|
17 |
-
BlueprintGenerator,
|
18 |
-
)
|
19 |
|
20 |
imagenet_templates_small = [
|
21 |
"a photo of a {}",
|
@@ -83,6 +79,7 @@ def train(args):
|
|
83 |
train_util.prepare_dataset_args(args, True)
|
84 |
|
85 |
cache_latents = args.cache_latents
|
|
|
86 |
|
87 |
if args.seed is not None:
|
88 |
set_seed(args.seed)
|
@@ -142,35 +139,21 @@ def train(args):
|
|
142 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
143 |
|
144 |
# データセットを準備する
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
else:
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
}
|
161 |
-
else:
|
162 |
-
print("Train with captions.")
|
163 |
-
user_config = {
|
164 |
-
"datasets": [{
|
165 |
-
"subsets": [{
|
166 |
-
"image_dir": args.train_data_dir,
|
167 |
-
"metadata_file": args.in_json,
|
168 |
-
}]
|
169 |
-
}]
|
170 |
-
}
|
171 |
-
|
172 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
173 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
174 |
|
175 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
176 |
if use_template:
|
@@ -180,30 +163,20 @@ def train(args):
|
|
180 |
captions = []
|
181 |
for tmpl in templates:
|
182 |
captions.append(tmpl.format(replace_to))
|
183 |
-
|
|
|
|
|
|
|
184 |
|
185 |
-
|
186 |
-
prompt_replacement = (args.token_string, replace_to)
|
187 |
-
else:
|
188 |
-
prompt_replacement = None
|
189 |
-
else:
|
190 |
-
if args.num_vectors_per_token > 1:
|
191 |
-
replace_to = " ".join(token_strings)
|
192 |
-
train_dataset_group.add_replacement(args.token_string, replace_to)
|
193 |
-
prompt_replacement = (args.token_string, replace_to)
|
194 |
-
else:
|
195 |
-
prompt_replacement = None
|
196 |
|
197 |
if args.debug_dataset:
|
198 |
-
train_util.debug_dataset(
|
199 |
return
|
200 |
-
if len(
|
201 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
202 |
return
|
203 |
|
204 |
-
if cache_latents:
|
205 |
-
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
206 |
-
|
207 |
# モデルに xformers とか memory efficient attention を組み込む
|
208 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
209 |
|
@@ -213,7 +186,7 @@ def train(args):
|
|
213 |
vae.requires_grad_(False)
|
214 |
vae.eval()
|
215 |
with torch.no_grad():
|
216 |
-
|
217 |
vae.to("cpu")
|
218 |
if torch.cuda.is_available():
|
219 |
torch.cuda.empty_cache()
|
@@ -225,14 +198,35 @@ def train(args):
|
|
225 |
|
226 |
# 学習に必要なクラスを準備する
|
227 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
229 |
-
|
|
|
|
|
230 |
|
231 |
# dataloaderを準備する
|
232 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
233 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
234 |
train_dataloader = torch.utils.data.DataLoader(
|
235 |
-
|
236 |
|
237 |
# 学習ステップ数を計算する
|
238 |
if args.max_train_epochs is not None:
|
@@ -240,9 +234,8 @@ def train(args):
|
|
240 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
241 |
|
242 |
# lr schedulerを用意する
|
243 |
-
lr_scheduler =
|
244 |
-
|
245 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
246 |
|
247 |
# acceleratorがなんかよろしくやってくれるらしい
|
248 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
@@ -290,8 +283,8 @@ def train(args):
|
|
290 |
# 学習する
|
291 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
292 |
print("running training / 学習開始")
|
293 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
294 |
-
print(f" num reg images / 正則化画像の数: {
|
295 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
296 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
297 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -310,11 +303,12 @@ def train(args):
|
|
310 |
|
311 |
for epoch in range(num_train_epochs):
|
312 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
313 |
-
|
314 |
|
315 |
text_encoder.train()
|
316 |
|
317 |
loss_total = 0
|
|
|
318 |
for step, batch in enumerate(train_dataloader):
|
319 |
with accelerator.accumulate(text_encoder):
|
320 |
with torch.no_grad():
|
@@ -363,9 +357,9 @@ def train(args):
|
|
363 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
364 |
|
365 |
accelerator.backward(loss)
|
366 |
-
if accelerator.sync_gradients
|
367 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
368 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
369 |
|
370 |
optimizer.step()
|
371 |
lr_scheduler.step()
|
@@ -380,14 +374,9 @@ def train(args):
|
|
380 |
progress_bar.update(1)
|
381 |
global_step += 1
|
382 |
|
383 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
384 |
-
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
385 |
-
|
386 |
current_loss = loss.detach().item()
|
387 |
if args.logging_dir is not None:
|
388 |
-
logs = {"loss": current_loss, "lr":
|
389 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
390 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
391 |
accelerator.log(logs, step=global_step)
|
392 |
|
393 |
loss_total += current_loss
|
@@ -405,6 +394,8 @@ def train(args):
|
|
405 |
accelerator.wait_for_everyone()
|
406 |
|
407 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
|
|
|
408 |
|
409 |
if args.save_every_n_epochs is not None:
|
410 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
@@ -426,9 +417,6 @@ def train(args):
|
|
426 |
if saving and args.save_state:
|
427 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
428 |
|
429 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
430 |
-
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
431 |
-
|
432 |
# end of epoch
|
433 |
|
434 |
is_main_process = accelerator.is_main_process
|
@@ -503,8 +491,6 @@ if __name__ == '__main__':
|
|
503 |
train_util.add_sd_models_arguments(parser)
|
504 |
train_util.add_dataset_arguments(parser, True, True, False)
|
505 |
train_util.add_training_arguments(parser, True)
|
506 |
-
train_util.add_optimizer_arguments(parser)
|
507 |
-
config_util.add_config_arguments(parser)
|
508 |
|
509 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
510 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
|
|
|
|
|
|
|
|
15 |
|
16 |
imagenet_templates_small = [
|
17 |
"a photo of a {}",
|
|
|
79 |
train_util.prepare_dataset_args(args, True)
|
80 |
|
81 |
cache_latents = args.cache_latents
|
82 |
+
use_dreambooth_method = args.in_json is None
|
83 |
|
84 |
if args.seed is not None:
|
85 |
set_seed(args.seed)
|
|
|
139 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
140 |
|
141 |
# データセットを準備する
|
142 |
+
if use_dreambooth_method:
|
143 |
+
print("Use DreamBooth method.")
|
144 |
+
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
145 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
146 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
147 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
148 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
149 |
else:
|
150 |
+
print("Train with captions.")
|
151 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
152 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
153 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
154 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
155 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
156 |
+
args.dataset_repeats, args.debug_dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
159 |
if use_template:
|
|
|
163 |
captions = []
|
164 |
for tmpl in templates:
|
165 |
captions.append(tmpl.format(replace_to))
|
166 |
+
train_dataset.add_replacement("", captions)
|
167 |
+
elif args.num_vectors_per_token > 1:
|
168 |
+
replace_to = " ".join(token_strings)
|
169 |
+
train_dataset.add_replacement(args.token_string, replace_to)
|
170 |
|
171 |
+
train_dataset.make_buckets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
if args.debug_dataset:
|
174 |
+
train_util.debug_dataset(train_dataset, show_input_ids=True)
|
175 |
return
|
176 |
+
if len(train_dataset) == 0:
|
177 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
178 |
return
|
179 |
|
|
|
|
|
|
|
180 |
# モデルに xformers とか memory efficient attention を組み込む
|
181 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
182 |
|
|
|
186 |
vae.requires_grad_(False)
|
187 |
vae.eval()
|
188 |
with torch.no_grad():
|
189 |
+
train_dataset.cache_latents(vae)
|
190 |
vae.to("cpu")
|
191 |
if torch.cuda.is_available():
|
192 |
torch.cuda.empty_cache()
|
|
|
198 |
|
199 |
# 学習に必要なクラスを準備する
|
200 |
print("prepare optimizer, data loader etc.")
|
201 |
+
|
202 |
+
# 8-bit Adamを使う
|
203 |
+
if args.use_8bit_adam:
|
204 |
+
try:
|
205 |
+
import bitsandbytes as bnb
|
206 |
+
except ImportError:
|
207 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
208 |
+
print("use 8-bit Adam optimizer")
|
209 |
+
optimizer_class = bnb.optim.AdamW8bit
|
210 |
+
elif args.use_lion_optimizer:
|
211 |
+
try:
|
212 |
+
import lion_pytorch
|
213 |
+
except ImportError:
|
214 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
215 |
+
print("use Lion optimizer")
|
216 |
+
optimizer_class = lion_pytorch.Lion
|
217 |
+
else:
|
218 |
+
optimizer_class = torch.optim.AdamW
|
219 |
+
|
220 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
221 |
+
|
222 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
223 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
224 |
|
225 |
# dataloaderを準備する
|
226 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
227 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
228 |
train_dataloader = torch.utils.data.DataLoader(
|
229 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
230 |
|
231 |
# 学習ステップ数を計算する
|
232 |
if args.max_train_epochs is not None:
|
|
|
234 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
235 |
|
236 |
# lr schedulerを用意する
|
237 |
+
lr_scheduler = diffusers.optimization.get_scheduler(
|
238 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
|
|
239 |
|
240 |
# acceleratorがなんかよろしくやってくれるらしい
|
241 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
283 |
# 学習する
|
284 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
285 |
print("running training / 学習開始")
|
286 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
287 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
288 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
289 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
290 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
303 |
|
304 |
for epoch in range(num_train_epochs):
|
305 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
306 |
+
train_dataset.set_current_epoch(epoch + 1)
|
307 |
|
308 |
text_encoder.train()
|
309 |
|
310 |
loss_total = 0
|
311 |
+
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
312 |
for step, batch in enumerate(train_dataloader):
|
313 |
with accelerator.accumulate(text_encoder):
|
314 |
with torch.no_grad():
|
|
|
357 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
358 |
|
359 |
accelerator.backward(loss)
|
360 |
+
if accelerator.sync_gradients:
|
361 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
362 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
363 |
|
364 |
optimizer.step()
|
365 |
lr_scheduler.step()
|
|
|
374 |
progress_bar.update(1)
|
375 |
global_step += 1
|
376 |
|
|
|
|
|
|
|
377 |
current_loss = loss.detach().item()
|
378 |
if args.logging_dir is not None:
|
379 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
380 |
accelerator.log(logs, step=global_step)
|
381 |
|
382 |
loss_total += current_loss
|
|
|
394 |
accelerator.wait_for_everyone()
|
395 |
|
396 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
397 |
+
# d = updated_embs - bef_epo_embs
|
398 |
+
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
399 |
|
400 |
if args.save_every_n_epochs is not None:
|
401 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
|
|
417 |
if saving and args.save_state:
|
418 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
419 |
|
|
|
|
|
|
|
420 |
# end of epoch
|
421 |
|
422 |
is_main_process = accelerator.is_main_process
|
|
|
491 |
train_util.add_sd_models_arguments(parser)
|
492 |
train_util.add_dataset_arguments(parser, True, True, False)
|
493 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
494 |
|
495 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
496 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|