abc committed · Commit 74be2a5 · 1 Parent: fbecb28
Upload 35 files

- append_module.py +56 -378
- fine_tune.py +45 -50
- gen_img_diffusers.py +48 -213
- library/train_util.py +230 -823
- networks/lora.py +0 -5
- tools/convert_diffusers20_original_sd.py +89 -0
- tools/detect_face_rotate.py +239 -0
- tools/resize_images_to_resolution.py +122 -0
- train_db.py +45 -47
- train_network.py +156 -212
- train_network_opt.py +373 -324
- train_textual_inversion.py +59 -68
append_module.py
CHANGED
@@ -2,19 +2,7 @@ import argparse
 import json
 import shutil
 import time
-from typing import (
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
-from dataclasses import (
-    asdict,
-    dataclass,
-)
+from typing import Dict, List, NamedTuple, Tuple
 from accelerate import Accelerator
 from torch.autograd.function import Function
 import glob
@@ -40,7 +28,6 @@ import safetensors.torch
 
 import library.model_util as model_util
 import library.train_util as train_util
-import library.config_util as config_util
 
 #============================================================================================================
 #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
@@ -128,124 +115,6 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
   return area_size_resos_list, area_size_list
 
 #============================================================================================================
-#config_util 内より
-#============================================================================================================
-@dataclass
-class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
-  min_resolution: Optional[Tuple[int, int]] = None
-  area_step : int = 2
-
-class ConfigSanitizer(config_util.ConfigSanitizer):
-  #@config_util.curry
-  @staticmethod
-  def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
-    config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
-    return tuple(value)
-
-  #@config_util.curry
-  @staticmethod
-  def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
-    config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
-    try:
-      config_util.Schema(klass)(value)
-      return (value, value)
-    except:
-      return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
-  # datasets schema
-  DATASET_ASCENDABLE_SCHEMA = {
-    "batch_size": int,
-    "bucket_no_upscale": bool,
-    "bucket_reso_steps": int,
-    "enable_bucket": bool,
-    "max_bucket_reso": int,
-    "min_bucket_reso": int,
-    "resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
-    "min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
-    "area_step": int,
-  }
-  def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
-    super().__init__(support_dreambooth, support_finetuning, support_dropout)
-  def _check(self):
-    print(self.db_dataset_schema)
-
-class BlueprintGenerator(config_util.BlueprintGenerator):
-  def __init__(self, sanitizer: ConfigSanitizer):
-    config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
-    super().__init__(sanitizer)
-
-def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
-  datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
-
-  for dataset_blueprint in dataset_group_blueprint.datasets:
-    if dataset_blueprint.is_dreambooth:
-      subset_klass = train_util.DreamBoothSubset
-      dataset_klass = DreamBoothDataset
-    else:
-      subset_klass = train_util.FineTuningSubset
-      dataset_klass = train_util.FineTuningDataset
-
-    subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-    dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
-    datasets.append(dataset)
-
-  # print info
-  info = ""
-  for i, dataset in enumerate(datasets):
-    is_dreambooth = isinstance(dataset, DreamBoothDataset)
-    info += config_util.dedent(f"""\
-      [Dataset {i}]
-        batch_size: {dataset.batch_size}
-        resolution: {(dataset.width, dataset.height)}
-        enable_bucket: {dataset.enable_bucket}
-      """)
-
-    if dataset.enable_bucket:
-      info += config_util.indent(config_util.dedent(f"""\
-        min_bucket_reso: {dataset.min_bucket_reso}
-        max_bucket_reso: {dataset.max_bucket_reso}
-        bucket_reso_steps: {dataset.bucket_reso_steps}
-        bucket_no_upscale: {dataset.bucket_no_upscale}
-      \n"""), "  ")
-    else:
-      info += "\n"
-
-    for j, subset in enumerate(dataset.subsets):
-      info += config_util.indent(config_util.dedent(f"""\
-        [Subset {j} of Dataset {i}]
-          image_dir: "{subset.image_dir}"
-          image_count: {subset.img_count}
-          num_repeats: {subset.num_repeats}
-          shuffle_caption: {subset.shuffle_caption}
-          keep_tokens: {subset.keep_tokens}
-          caption_dropout_rate: {subset.caption_dropout_rate}
-          caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
-          caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
-          color_aug: {subset.color_aug}
-          flip_aug: {subset.flip_aug}
-          face_crop_aug_range: {subset.face_crop_aug_range}
-          random_crop: {subset.random_crop}
-      """), "  ")
-
-      if is_dreambooth:
-        info += config_util.indent(config_util.dedent(f"""\
-          is_reg: {subset.is_reg}
-          class_tokens: {subset.class_tokens}
-          caption_extension: {subset.caption_extension}
-        \n"""), "    ")
-      else:
-        info += config_util.indent(config_util.dedent(f"""\
-          metadata_file: {subset.metadata_file}
-        \n"""), "    ")
-
-  print(info)
-
-  # make buckets first because it determines the length of dataset
-  for i, dataset in enumerate(datasets):
-    print(f"[Dataset {i}]")
-    dataset.make_buckets()
-
-  return train_util.DatasetGroup(datasets)
-#============================================================================================================
 #train_util 内より
 #============================================================================================================
 class BucketManager_append(train_util.BucketManager):
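Note: the deleted __validate_and_convert_scalar_or_twodim helper only normalized a resolution given either as a single int or as a [width, height] pair. A minimal standalone sketch of that behaviour, without the config_util/voluptuous schema machinery (names here are illustrative, not part of the repo):

# Minimal sketch of the scalar-or-pair normalization the removed helper performed.
from typing import Sequence, Tuple, Union

def to_size_pair(value: Union[int, Sequence[int]]) -> Tuple[int, int]:
  # a bare int such as 512 becomes (512, 512); a two-element sequence
  # such as [512, 768] becomes (512, 768)
  if isinstance(value, int):
    return (value, value)
  w, h = value  # raises if it is not a two-element sequence
  return (int(w), int(h))

print(to_size_pair(512))         # (512, 512)
print(to_size_pair([512, 768]))  # (512, 768)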
@@ -310,7 +179,7 @@ class BucketManager_append(train_util.BucketManager):
         bucket_size_id_list.append(bucket_size_id + i + 1)
       _min_error = 1000.
       _min_id = bucket_size_id
-      for now_size_id in
+      for now_size_id in bucket_size_id:
         self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
         ar_errors = self.predefined_aspect_ratios - aspect_ratio
         ar_error = np.abs(ar_errors).min()
@@ -384,13 +253,13 @@ class BucketManager_append(train_util.BucketManager):
     return reso, resized_size, ar_error
 
 class DreamBoothDataset(train_util.DreamBoothDataset):
-  def __init__(self,
+  def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
     print("use append DreamBoothDataset")
     self.min_resolution = min_resolution
     self.area_step = area_step
-    super().__init__(
-
-
+    super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
+                     resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
+                     flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
   def make_buckets(self):
     '''
     bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
@@ -483,50 +352,40 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
     self.shuffle_buckets()
     self._length = len(self.buckets_indices)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+class FineTuningDataset(train_util.FineTuningDataset):
+  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
+    train_util.glob_images = glob_images
+    super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
+                      resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
+                      random_crop, dataset_repeats, debug_dataset)
+
+def glob_images(directory, base="*", npz_flag=True):
+  img_paths = []
+  dots = []
+  for ext in train_util.IMAGE_EXTENSIONS:
+    dots.append(ext)
+  if npz_flag:
+    dots.append(".npz")
+  for ext in dots:
+    if base == '*':
+      img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+    else:
+      img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+  return img_paths
+
 #============================================================================================================
 #networks.lora
 #============================================================================================================
-
-def replace_prepare_optimizer_params(networks
-  def prepare_optimizer_params(self, text_encoder_lr, unet_lr,
-
+from networks.lora import LoRANetwork
+def replace_prepare_optimizer_params(networks):
+  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
+
   def enumerate_params(loras, lora_name=None):
     params = []
     for lora in loras:
       if lora_name is not None:
-
-
-        lora_names = [lora_name]
-        if "attentions" in lora_name:
-          lora_names.append(lora_name.replace("attentions", "resnets"))
-        elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
-          lora_names.append("lora_unet_up_blocks_0_upsamplers_")
-        elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
-          lora_names.append("lora_unet_up_blocks_1_upsamplers_")
-        elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
-          lora_names.append("lora_unet_up_blocks_2_upsamplers_")
-
-        for _name in lora_names:
-          if _name in lora.lora_name:
-            get_param_flag = True
-            break
-        else:
-          if lora_name in lora.lora_name:
-            get_param_flag = True
-        if get_param_flag: params.extend(lora.parameters())
+        if lora_name in lora.lora_name:
+          params.extend(lora.parameters())
       else:
         params.extend(lora.parameters())
     return params
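The glob_images helper added above collects image files, and optionally the cached .npz latents sitting next to them, under one directory. A self-contained mirror of the same logic for illustration; the suffix list is an assumption standing in for train_util.IMAGE_EXTENSIONS:

# Self-contained mirror of the glob_images helper added above (illustrative only).
import glob
import os

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]  # assumed suffix list

def glob_images(directory, base="*", npz_flag=True):
  dots = list(IMAGE_EXTENSIONS)
  if npz_flag:
    dots.append(".npz")  # cached latents are picked up together with the images
  img_paths = []
  for ext in dots:
    if base == "*":
      img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
    else:
      img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
  return img_paths

print(len(glob_images("train_data/1_mychar")), "files found (images plus cached .npz latents)")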
@@ -534,7 +393,6 @@ def replace_prepare_optimizer_params(networks, network_module):
     self.requires_grad_(True)
     all_params = []
     ret_scheduler_lr = []
-    used_names = []
 
     if loranames is not None:
       textencoder_names = [None]
@@ -547,181 +405,37 @@ def replace_prepare_optimizer_params(networks, network_module):
     if self.text_encoder_loras:
       for textencoder_name in textencoder_names:
         param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
-        used_names.append(textencoder_name)
         if text_encoder_lr is not None:
           param_data['lr'] = text_encoder_lr
-
-
-            param_data['lr'] = lr_dic[textencoder_name]
-            print(f"{textencoder_name} lr: {param_data['lr']}")
-
-        if block_args_dic is not None:
-          if "lora_te_" in block_args_dic:
-            for pname, value in block_args_dic["lora_te_"].items():
-              param_data[pname] = value
-          if textencoder_name in block_args_dic:
-            for pname, value in block_args_dic[textencoder_name].items():
-              param_data[pname] = value
-
-        if text_encoder_lr is not None:
-          ret_scheduler_lr.append(text_encoder_lr)
-        else:
-          ret_scheduler_lr.append(0.)
-        if lr_dic is not None:
-          if textencoder_name in lr_dic:
-            ret_scheduler_lr[-1] = lr_dic[textencoder_name]
+        if scheduler_lr is not None:
+          ret_scheduler_lr.append(scheduler_lr[0])
         all_params.append(param_data)
 
     if self.unet_loras:
       for unet_name in unet_names:
         param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
-        if len(param_data["params"])==0: continue
-        used_names.append(unet_name)
         if unet_lr is not None:
           param_data['lr'] = unet_lr
-
-
-            param_data['lr'] = lr_dic[unet_name]
-            print(f"{unet_name} lr: {param_data['lr']}")
-
-        if block_args_dic is not None:
-          if "lora_unet_" in block_args_dic:
-            for pname, value in block_args_dic["lora_unet_"].items():
-              param_data[pname] = value
-          if unet_name in block_args_dic:
-            for pname, value in block_args_dic[unet_name].items():
-              param_data[pname] = value
-
-        if unet_lr is not None:
-          ret_scheduler_lr.append(unet_lr)
-        else:
-          ret_scheduler_lr.append(0.)
-        if lr_dic is not None:
-          if unet_name in lr_dic:
-            ret_scheduler_lr[-1] = lr_dic[unet_name]
+        if scheduler_lr is not None:
+          ret_scheduler_lr.append(scheduler_lr[1])
         all_params.append(param_data)
 
-    return all_params,
-
-
-  except:
-    print("cant't replace prepare_optimizer_params")
+    return all_params, ret_scheduler_lr
+
+  LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
 
 #============================================================================================================
 #新規追加
 #============================================================================================================
 def add_append_arguments(parser: argparse.ArgumentParser):
   # for train_network_opt.py
-
-
-  parser.add_argument("--use_lookahead", action="store_true")
-  parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
+  parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
+  parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
   parser.add_argument("--split_lora_networks", action="store_true")
   parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
-  parser.add_argument("--blocks_lr_setting", type=str, default=None)
-  parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
   parser.add_argument("--min_resolution", type=str, default=None)
   parser.add_argument("--area_step", type=int, default=1)
   parser.add_argument("--config", type=str, default=None)
-  parser.add_argument("--not_output_config", action="store_true")
-
-class MyNetwork_Names:
-  ex_block_weight_dic = {
-    "BASE": ["te"],
-    "IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
-    "IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
-    "IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
-    "IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
-    "MID": ["mid"],
-    "OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
-    "OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
-    "OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
-    "OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
-  }
-
-  blocks_name_dic = { "te": "lora_te_",
-                      "unet": "lora_unet_",
-                      "mid": "lora_unet_mid_block_",
-                      "down": "lora_unet_down_blocks_",
-                      "up": "lora_unet_up_blocks_"}
-  for i in range(12):
-    blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
-  for i in range(3):
-    blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
-    blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
-  for i in range(4):
-    for j in range(2):
-      if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
-      blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
-    for j in range(3):
-      if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
-      blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
-    if i<=2:
-      blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
-      blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
-
-def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
-  ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
-  blocks_name_dic = MyNetwork_Names.blocks_name_dic
-
-  lr_dic = {}
-  if lr_setting_str==None or lr_setting_str=="":
-    pass
-  else:
-    lr_settings = lr_setting_str.replace(" ", "").split(",")
-    for lr_setting in lr_settings:
-      key, value = lr_setting.split("=")
-      if key in ex_block_weight_dic:
-        keys = ex_block_weight_dic[key]
-      else:
-        keys = [key]
-      for key in keys:
-        if key in blocks_name_dic:
-          new_key = blocks_name_dic[key]
-          lr_dic[new_key] = float(value)
-  if len(lr_dic)==0:
-    lr_dic = None
-
-  args_dic = {}
-  if (block_optim_args is None):
-    block_optim_args = []
-  if (len(block_optim_args)>0):
-    for my_arg in block_optim_args:
-      my_arg = my_arg.replace(" ", "")
-      splits = my_arg.split(":")
-      b_name = splits[0]
-
-      key, _value = splits[1].split("=")
-      value_type = float
-      if len(splits)==3:
-        if _value=="str":
-          value_type = str
-        elif _value=="int":
-          value_type = int
-        _value = splits[2]
-      if _value=="true" or _value=="false":
-        value_type = bool
-      if "," in _value:
-        _value = _value.split(",")
-        for i in range(len(_value)):
-          _value[i] = value_type(_value[i])
-        value=tuple(_value)
-      else:
-        value = value_type(_value)
-
-      if b_name in ex_block_weight_dic:
-        b_names = ex_block_weight_dic[b_name]
-      else:
-        b_names = [b_name]
-      for b_name in b_names:
-        new_b_name = blocks_name_dic[b_name]
-        if not new_b_name in args_dic:
-          args_dic[new_b_name] = {}
-        args_dic[new_b_name][key] = value
-
-  if len(args_dic)==0:
-    args_dic = None
-  return lr_dic, args_dic
 
 def create_split_names(split_flag, split_level):
   split_names = None
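For reference, the deleted create_lr_blocks turned a --blocks_lr_setting string such as "BASE=5e-5,IN01=1e-4,MID=2e-4" into a dict keyed by LoRA module-name prefixes, expanding the webui-style block aliases first. A simplified sketch of that parsing; only a few alias entries are shown and the tables are trimmed, so this is illustrative rather than the repo's code:

# Simplified sketch of the removed --blocks_lr_setting parsing (tables trimmed).
ALIASES = {"BASE": ["te"], "MID": ["mid"], "IN01": ["down_0_at_0", "down_0_res_0"]}
PREFIXES = {"te": "lora_te_", "mid": "lora_unet_mid_block_",
            "down_0_at_0": "lora_unet_down_blocks_0_attentions_0_",
            "down_0_res_0": "lora_unet_down_blocks_0_resnets_0"}

def parse_block_lrs(setting: str):
  lr_dic = {}
  for item in setting.replace(" ", "").split(","):
    key, value = item.split("=")
    for block in ALIASES.get(key, [key]):   # expand an alias like IN01 to its module keys
      if block in PREFIXES:
        lr_dic[PREFIXES[block]] = float(value)
  return lr_dic or None

print(parse_block_lrs("BASE=5e-5,IN01=1e-4,MID=2e-4"))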
@@ -732,28 +446,14 @@ def create_split_names(split_flag, split_level):
     if split_level==1:
       unet_names.append(f"lora_unet_down_blocks_")
       unet_names.append(f"lora_unet_up_blocks_")
-    elif split_level==2 or split_level==0
-      if split_level
+    elif split_level==2 or split_level==0:
+      if split_level==2:
         text_encoder_names = []
         for i in range(12):
           text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
-
-
-
-      unet_names.append(f"lora_unet_down_blocks_{i}")
-      unet_names.append(f"lora_unet_up_blocks_{i+1}")
-
-    if split_level>=3:
-      for i in range(4):
-        for j in range(2):
-          if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
-          if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
-        for j in range(3):
-          if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
-          if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
-        if i<=2:
-          unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
-
+      for i in range(3):
+        unet_names.append(f"lora_unet_down_blocks_{i}")
+        unet_names.append(f"lora_unet_up_blocks_{i+1}")
     split_names["text_encoder"] = text_encoder_names
     split_names["unet"] = unet_names
   return split_names
@@ -765,7 +465,7 @@ def get_config(parser):
   import datetime
   if os.path.splitext(args.config)[-1] == ".yaml":
     args.config = os.path.splitext(args.config)[0]
-  config_path = f"{args.config}.yaml"
+  config_path = f"./{args.config}.yaml"
   if os.path.exists(config_path):
     print(f"{config_path} から設定を読み込み中...")
     margs, rest = parser.parse_known_args()
@@ -786,41 +486,19 @@ def get_config(parser):
         args_type_dic[key] = act.type
     #データタイプの確認とargsにkeyの内容を代入していく
     for key, v in configs.items():
-      if
-        if key
-
-
-
-
-
+      if key in args_dic:
+        if args_dic[key] is not None:
+          new_type = type(args_dic[key])
+          if (not type(v) == new_type) and (not new_type==list):
+            v = new_type(v)
+      else:
+        if v is not None:
           if not type(v) == args_type_dic[key]:
             v = args_type_dic[key](v)
-
+      args_dic[key] = v
     #最後にデフォから指定が変わってるものを変更する
     for key, v in change_def_dic.items():
       args_dic[key] = v
   else:
     print(f"{config_path} が見つかりませんでした")
   return args
-
-'''
-class GradientReversalFunction(torch.autograd.Function):
-  @staticmethod
-  def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
-    ctx.save_for_backward(scale)
-    return input_forward
-  @staticmethod
-  def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale, = ctx.saved_tensors
-    return scale * -grad_backward, None
-
-class GradientReversal(torch.nn.Module):
-  def __init__(self, scale: float):
-    super(GradientReversal, self).__init__()
-    self.scale = torch.tensor(scale)
-  def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
-    if flag:
-      return x
-    else:
-      return GradientReversalFunction.apply(x, self.scale)
-'''
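The networks.lora section of append_module.py now swaps in a simpler prepare_optimizer_params that also returns the initial learning rates the scheduler should use (so the layer-wise AdafactorSchedule mentioned at the top of the file can be seeded per parameter group). The class-method rebinding pattern it relies on, as a self-contained sketch with a toy class standing in for networks.lora.LoRANetwork:

# Self-contained sketch of the monkeypatch pattern used above (ToyNetwork is a stand-in).
class ToyNetwork:
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
    return [{"params": [], "lr": text_encoder_lr}, {"params": [], "lr": unet_lr}]

def replace_prepare_optimizer_params():
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None):
    all_params = [{"params": [], "lr": text_encoder_lr}, {"params": [], "lr": unet_lr}]
    # also hand back the initial lrs the scheduler should start from, one per group
    ret_scheduler_lr = list(scheduler_lr) if scheduler_lr is not None else []
    return all_params, ret_scheduler_lr
  # rebind the method on the class, just as the hunk above does for LoRANetwork
  ToyNetwork.prepare_optimizer_params = prepare_optimizer_params

replace_prepare_optimizer_params()
print(ToyNetwork().prepare_optimizer_params(1e-5, 1e-4, scheduler_lr=(1e-5, 1e-4)))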
fine_tune.py
CHANGED
@@ -13,11 +13,7 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
-
-from library.config_util import (
-  ConfigSanitizer,
-  BlueprintGenerator,
-)
+
 
 def collate_fn(examples):
   return examples[0]
@@ -34,36 +30,25 @@ def train(args):
 
   tokenizer = train_util.load_tokenizer(args)
 
-
-
-
-
-
-
-
-
-
-
-
-        "image_dir": args.train_data_dir,
-        "metadata_file": args.in_json,
-      }]
-    }]
-  }
-
-  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+  train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
+                                               tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
+                                               args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
+                                               args.bucket_reso_steps, args.bucket_no_upscale,
+                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
+                                               args.dataset_repeats, args.debug_dataset)
+
+  # 学習データのdropout率を設定する
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
+
+  train_dataset.make_buckets()
 
   if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset)
     return
-  if len(
+  if len(train_dataset) == 0:
     print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
     return
 
-  if cache_latents:
-    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
-
   # acceleratorを準備する
   print("prepare accelerator")
   accelerator, unwrap_model = train_util.prepare_accelerator(args)
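A note on the DataLoader settings restored in this file: these datasets return an entire pre-bucketed batch per index, which is why the loader below is created with batch_size=1 and a collate_fn that simply unwraps the single element. A small self-contained sketch of that pattern (the stand-in dataset is illustrative):

# Why the DataLoader uses batch_size=1: each dataset item is already a full batch.
import torch

class BatchPerIndexDataset(torch.utils.data.Dataset):
  # stand-in for train_util.FineTuningDataset, which also yields whole batches
  def __init__(self, batches):
    self.batches = batches
  def __len__(self):
    return len(self.batches)
  def __getitem__(self, i):
    return self.batches[i]  # already a dict covering train_batch_size samples

def collate_fn(examples):
  return examples[0]  # same helper as in fine_tune.py

loader = torch.utils.data.DataLoader(BatchPerIndexDataset([{"ids": [1, 2]}, {"ids": [3, 4]}]),
                                     batch_size=1, shuffle=False, collate_fn=collate_fn)
print([b for b in loader])  # each element is one pre-built batch dict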
@@ -124,7 +109,7 @@ def train(args):
     vae.requires_grad_(False)
     vae.eval()
     with torch.no_grad():
-
+      train_dataset.cache_latents(vae)
     vae.to("cpu")
     if torch.cuda.is_available():
       torch.cuda.empty_cache()
@@ -164,13 +149,33 @@ def train(args):
 
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-
+
+  # 8-bit Adamを使う
+  if args.use_8bit_adam:
+    try:
+      import bitsandbytes as bnb
+    except ImportError:
+      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
+    print("use 8-bit Adam optimizer")
+    optimizer_class = bnb.optim.AdamW8bit
+  elif args.use_lion_optimizer:
+    try:
+      import lion_pytorch
+    except ImportError:
+      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+    print("use Lion optimizer")
+    optimizer_class = lion_pytorch.Lion
+  else:
+    optimizer_class = torch.optim.AdamW
+
+  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
+  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -178,9 +183,8 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler =
-
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = diffusers.optimization.get_scheduler(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -214,7 +218,7 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f"  num examples / サンプル数: {
+  print(f"  num examples / サンプル数: {train_dataset.num_train_images}")
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -233,7 +237,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-
+    train_dataset.set_current_epoch(epoch + 1)
 
     for m in training_models:
       m.train()
@@ -282,11 +286,11 @@ def train(args):
         loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
       accelerator.backward(loss)
-      if accelerator.sync_gradients
+      if accelerator.sync_gradients:
         params_to_clip = []
         for m in training_models:
           params_to_clip.extend(m.parameters())
-        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+        accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
 
       optimizer.step()
       lr_scheduler.step()
@@ -297,16 +301,11 @@ def train(args):
         progress_bar.update(1)
         global_step += 1
 
-        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
       current_loss = loss.detach().item()      # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr":
-        if args.optimizer_type.lower() == "DAdaptation".lower():      # tracking d*lr value
-          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
+        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
        accelerator.log(logs, step=global_step)
 
-      # TODO moving averageにする
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
       logs = {"loss": avr_loss}      # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -316,7 +315,7 @@ def train(args):
         break
 
     if args.logging_dir is not None:
-      logs = {"
+      logs = {"epoch_loss": loss_total / len(train_dataloader)}
       accelerator.log(logs, step=epoch+1)
 
     accelerator.wait_for_everyone()
@@ -326,8 +325,6 @@ def train(args):
       train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                             save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
 
-    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
   is_main_process = accelerator.is_main_process
   if is_main_process:
     unet = unwrap_model(unet)
@@ -354,8 +351,6 @@ if __name__ == '__main__':
   train_util.add_dataset_arguments(parser, False, True, True)
   train_util.add_training_arguments(parser, False)
   train_util.add_sd_saving_arguments(parser)
-  train_util.add_optimizer_arguments(parser)
-  config_util.add_config_arguments(parser)
 
   parser.add_argument("--diffusers_xformers", action='store_true',
                       help='use xformers by diffusers / Diffusersでxformersを使用する')
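The optimizer block restored in fine_tune.py prefers bitsandbytes' AdamW8bit or lion_pytorch's Lion only when the corresponding package imports cleanly, otherwise it falls back to torch.optim.AdamW. The same selection logic as a compact, self-contained sketch (function name is illustrative):

# Minimal sketch of the optional-optimizer fallback shown above.
import torch

def pick_optimizer_class(use_8bit_adam=False, use_lion=False):
  if use_8bit_adam:
    try:
      import bitsandbytes as bnb
      return bnb.optim.AdamW8bit
    except ImportError as e:
      raise ImportError("bitsandbytes is not installed") from e
  if use_lion:
    try:
      import lion_pytorch
      return lion_pytorch.Lion
    except ImportError as e:
      raise ImportError("lion_pytorch is not installed") from e
  return torch.optim.AdamW

print(pick_optimizer_class())  # falls back to torch.optim.AdamW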
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
-
from typing import
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
@@ -60,6 +60,7 @@ import math
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
|
|
63 |
|
64 |
import diffusers
|
65 |
import numpy as np
|
@@ -80,9 +81,6 @@ from PIL import Image
|
|
80 |
from PIL.PngImagePlugin import PngInfo
|
81 |
|
82 |
import library.model_util as model_util
|
83 |
-
import library.train_util as train_util
|
84 |
-
import tools.original_control_net as original_control_net
|
85 |
-
from tools.original_control_net import ControlNetInfo
|
86 |
|
87 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
88 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
@@ -489,9 +487,6 @@ class PipelineLike():
|
|
489 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
490 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
491 |
|
492 |
-
# ControlNet
|
493 |
-
self.control_nets: List[ControlNetInfo] = []
|
494 |
-
|
495 |
# Textual Inversion
|
496 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
497 |
self.token_replacements[target_token_id] = rep_token_ids
|
@@ -505,11 +500,7 @@ class PipelineLike():
|
|
505 |
new_tokens.append(token)
|
506 |
return new_tokens
|
507 |
|
508 |
-
def set_control_nets(self, ctrl_nets):
|
509 |
-
self.control_nets = ctrl_nets
|
510 |
-
|
511 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
512 |
-
|
513 |
def enable_xformers_memory_efficient_attention(self):
|
514 |
r"""
|
515 |
Enable memory efficient attention as implemented in xformers.
|
@@ -590,8 +581,6 @@ class PipelineLike():
|
|
590 |
latents: Optional[torch.FloatTensor] = None,
|
591 |
max_embeddings_multiples: Optional[int] = 3,
|
592 |
output_type: Optional[str] = "pil",
|
593 |
-
vae_batch_size: float = None,
|
594 |
-
return_latents: bool = False,
|
595 |
# return_dict: bool = True,
|
596 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
597 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
@@ -683,9 +672,6 @@ class PipelineLike():
|
|
683 |
else:
|
684 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
685 |
|
686 |
-
vae_batch_size = batch_size if vae_batch_size is None else (
|
687 |
-
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
|
688 |
-
|
689 |
if strength < 0 or strength > 1:
|
690 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
691 |
|
@@ -766,7 +752,7 @@ class PipelineLike():
|
|
766 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
767 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
768 |
|
769 |
-
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None
|
770 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
771 |
clip_guide_images = [clip_guide_images]
|
772 |
|
@@ -779,7 +765,7 @@ class PipelineLike():
|
|
779 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
780 |
if len(image_embeddings_clip) == 1:
|
781 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
782 |
-
|
783 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
784 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
785 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
@@ -788,10 +774,6 @@ class PipelineLike():
|
|
788 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
789 |
if len(image_embeddings_vgg16) == 1:
|
790 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
791 |
-
else:
|
792 |
-
# ControlNetのhintにguide imageを流用する
|
793 |
-
# 前処理はControlNet側で行う
|
794 |
-
pass
|
795 |
|
796 |
# set timesteps
|
797 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
@@ -799,6 +781,7 @@ class PipelineLike():
|
|
799 |
latents_dtype = text_embeddings.dtype
|
800 |
init_latents_orig = None
|
801 |
mask = None
|
|
|
802 |
|
803 |
if init_image is None:
|
804 |
# get the initial random noise unless the user supplied it
|
@@ -830,8 +813,6 @@ class PipelineLike():
|
|
830 |
if isinstance(init_image[0], PIL.Image.Image):
|
831 |
init_image = [preprocess_image(im) for im in init_image]
|
832 |
init_image = torch.cat(init_image)
|
833 |
-
if isinstance(init_image, list):
|
834 |
-
init_image = torch.stack(init_image)
|
835 |
|
836 |
# mask image to tensor
|
837 |
if mask_image is not None:
|
@@ -842,24 +823,9 @@ class PipelineLike():
|
|
842 |
|
843 |
# encode the init image into latents and scale the latents
|
844 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
if vae_batch_size >= batch_size:
|
849 |
-
init_latent_dist = self.vae.encode(init_image).latent_dist
|
850 |
-
init_latents = init_latent_dist.sample(generator=generator)
|
851 |
-
else:
|
852 |
-
if torch.cuda.is_available():
|
853 |
-
torch.cuda.empty_cache()
|
854 |
-
init_latents = []
|
855 |
-
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
856 |
-
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
|
857 |
-
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
|
858 |
-
init_latents.append(init_latent_dist.sample(generator=generator))
|
859 |
-
init_latents = torch.cat(init_latents)
|
860 |
-
|
861 |
-
init_latents = 0.18215 * init_latents
|
862 |
-
|
863 |
if len(init_latents) == 1:
|
864 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
865 |
init_latents_orig = init_latents
|
@@ -898,21 +864,12 @@ class PipelineLike():
|
|
898 |
extra_step_kwargs["eta"] = eta
|
899 |
|
900 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
901 |
-
|
902 |
-
if self.control_nets:
|
903 |
-
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
904 |
-
|
905 |
for i, t in enumerate(tqdm(timesteps)):
|
906 |
# expand the latents if we are doing classifier free guidance
|
907 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
908 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
909 |
-
|
910 |
# predict the noise residual
|
911 |
-
|
912 |
-
noise_pred = original_control_net.call_unet_and_control_net(
|
913 |
-
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
|
914 |
-
else:
|
915 |
-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
916 |
|
917 |
# perform guidance
|
918 |
if do_classifier_free_guidance:
|
@@ -954,19 +911,8 @@ class PipelineLike():
|
|
954 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
955 |
return None
|
956 |
|
957 |
-
if return_latents:
|
958 |
-
return (latents, False)
|
959 |
-
|
960 |
latents = 1 / 0.18215 * latents
|
961 |
-
|
962 |
-
image = self.vae.decode(latents).sample
|
963 |
-
else:
|
964 |
-
if torch.cuda.is_available():
|
965 |
-
torch.cuda.empty_cache()
|
966 |
-
images = []
|
967 |
-
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
968 |
-
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
|
969 |
-
image = torch.cat(images)
|
970 |
|
971 |
image = (image / 2 + 0.5).clamp(0, 1)
|
972 |
|
@@ -1853,7 +1799,7 @@ def preprocess_mask(mask):
|
|
1853 |
mask = mask.convert("L")
|
1854 |
w, h = mask.size
|
1855 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1856 |
-
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.
|
1857 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1858 |
mask = np.tile(mask, (4, 1, 1))
|
1859 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
@@ -1871,35 +1817,6 @@ def preprocess_mask(mask):
|
|
1871 |
# return text_encoder
|
1872 |
|
1873 |
|
1874 |
-
class BatchDataBase(NamedTuple):
|
1875 |
-
# バッチ分割が必要ないデータ
|
1876 |
-
step: int
|
1877 |
-
prompt: str
|
1878 |
-
negative_prompt: str
|
1879 |
-
seed: int
|
1880 |
-
init_image: Any
|
1881 |
-
mask_image: Any
|
1882 |
-
clip_prompt: str
|
1883 |
-
guide_image: Any
|
1884 |
-
|
1885 |
-
|
1886 |
-
class BatchDataExt(NamedTuple):
|
1887 |
-
# バッチ分割が必要なデータ
|
1888 |
-
width: int
|
1889 |
-
height: int
|
1890 |
-
steps: int
|
1891 |
-
scale: float
|
1892 |
-
negative_scale: float
|
1893 |
-
strength: float
|
1894 |
-
network_muls: Tuple[float]
|
1895 |
-
|
1896 |
-
|
1897 |
-
class BatchData(NamedTuple):
|
1898 |
-
return_latents: bool
|
1899 |
-
base: BatchDataBase
|
1900 |
-
ext: BatchDataExt
|
1901 |
-
|
1902 |
-
|
1903 |
def main(args):
|
1904 |
if args.fp16:
|
1905 |
dtype = torch.float16
|
@@ -1964,7 +1881,10 @@ def main(args):
|
|
1964 |
# tokenizerを読み込む
|
1965 |
print("loading tokenizer")
|
1966 |
if use_stable_diffusion_format:
|
1967 |
-
|
|
|
|
|
|
|
1968 |
|
1969 |
# schedulerを用意する
|
1970 |
sched_init_args = {}
|
@@ -2075,13 +1995,11 @@ def main(args):
|
|
2075 |
# networkを組み込む
|
2076 |
if args.network_module:
|
2077 |
networks = []
|
2078 |
-
network_default_muls = []
|
2079 |
for i, network_module in enumerate(args.network_module):
|
2080 |
print("import network module:", network_module)
|
2081 |
imported_module = importlib.import_module(network_module)
|
2082 |
|
2083 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
2084 |
-
network_default_muls.append(network_mul)
|
2085 |
|
2086 |
net_kwargs = {}
|
2087 |
if args.network_args and i < len(args.network_args):
|
@@ -2096,7 +2014,7 @@ def main(args):
|
|
2096 |
network_weight = args.network_weights[i]
|
2097 |
print("load network weights from:", network_weight)
|
2098 |
|
2099 |
-
if model_util.is_safetensors(network_weight)
|
2100 |
from safetensors.torch import safe_open
|
2101 |
with safe_open(network_weight, framework="pt") as f:
|
2102 |
metadata = f.metadata()
|
@@ -2119,18 +2037,6 @@ def main(args):
|
|
2119 |
else:
|
2120 |
networks = []
|
2121 |
|
2122 |
-
# ControlNetの処理
|
2123 |
-
control_nets: List[ControlNetInfo] = []
|
2124 |
-
if args.control_net_models:
|
2125 |
-
for i, model in enumerate(args.control_net_models):
|
2126 |
-
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
2127 |
-
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
2128 |
-
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
2129 |
-
|
2130 |
-
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
2131 |
-
prep = original_control_net.load_preprocess(prep_type)
|
2132 |
-
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
2133 |
-
|
2134 |
if args.opt_channels_last:
|
2135 |
print(f"set optimizing: channels last")
|
2136 |
text_encoder.to(memory_format=torch.channels_last)
|
@@ -2144,14 +2050,9 @@ def main(args):
|
|
2144 |
if vgg16_model is not None:
|
2145 |
vgg16_model.to(memory_format=torch.channels_last)
|
2146 |
|
2147 |
-
for cn in control_nets:
|
2148 |
-
cn.unet.to(memory_format=torch.channels_last)
|
2149 |
-
cn.net.to(memory_format=torch.channels_last)
|
2150 |
-
|
2151 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2152 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2153 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
2154 |
-
pipe.set_control_nets(control_nets)
|
2155 |
print("pipeline is ready.")
|
2156 |
|
2157 |
if args.diffusers_xformers:
|
@@ -2285,12 +2186,9 @@ def main(args):
|
|
2285 |
|
2286 |
prev_image = None # for VGG16 guided
|
2287 |
if args.guide_image_path is not None:
|
2288 |
-
print(f"load image for CLIP/VGG16
|
2289 |
-
guide_images =
|
2290 |
-
for
|
2291 |
-
guide_images.extend(load_images(p))
|
2292 |
-
|
2293 |
-
print(f"loaded {len(guide_images)} guide images for guidance")
|
2294 |
if len(guide_images) == 0:
|
2295 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2296 |
guide_images = None
|
@@ -2321,46 +2219,33 @@ def main(args):
|
|
2321 |
iter_seed = random.randint(0, 0x7fffffff)
|
2322 |
|
2323 |
# バッチ処理の関数
|
2324 |
-
def process_batch(batch
|
2325 |
batch_size = len(batch)
|
2326 |
|
2327 |
# highres_fixの処理
|
2328 |
if highres_fix and not highres_1st:
|
2329 |
-
# 1st stage
|
2330 |
-
print("process 1st
|
2331 |
batch_1st = []
|
2332 |
-
for
|
2333 |
-
width_1st = int(
|
2334 |
-
height_1st = int(
|
2335 |
width_1st = width_1st - width_1st % 32
|
2336 |
height_1st = height_1st - height_1st % 32
|
2337 |
-
|
2338 |
-
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
2339 |
-
ext.negative_scale, ext.strength, ext.network_muls)
|
2340 |
-
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
2341 |
images_1st = process_batch(batch_1st, True, True)
|
2342 |
|
2343 |
# 2nd stageのバッチを作成して以下処理する
|
2344 |
-
print("process 2nd
|
2345 |
-
if args.highres_fix_latents_upscaling:
|
2346 |
-
org_dtype = images_1st.dtype
|
2347 |
-
if images_1st.dtype == torch.bfloat16:
|
2348 |
-
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
2349 |
-
images_1st = torch.nn.functional.interpolate(
|
2350 |
-
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
|
2351 |
-
images_1st = images_1st.to(org_dtype)
|
2352 |
-
|
2353 |
batch_2nd = []
|
2354 |
-
for i, (
|
2355 |
-
|
2356 |
-
|
2357 |
-
|
2358 |
-
batch_2nd.append(bd_2nd)
|
2359 |
batch = batch_2nd
|
2360 |
|
2361 |
-
|
2362 |
-
|
2363 |
-
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
2364 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2365 |
|
2366 |
prompts = []
|
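As context for the highres-fix handling in this hunk: the first stage renders at the requested size multiplied by --highres_fix_scale and snapped down to a multiple of 32, and its output is then upscaled (as PIL images here; the removed --highres_fix_latents_upscaling path did it in latent space) before the second stage. A minimal sketch of the resolution rule, mirroring the width_1st/height_1st lines above (the helper name is ours):

def first_stage_size(width, height, scale):
    # scale the requested size, round to the nearest pixel, snap down to a multiple of 32
    w = int(width * scale + .5)
    h = int(height * scale + .5)
    return w - w % 32, h - h % 32

# e.g. first_stage_size(768, 1024, 0.5) -> (384, 512)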
@@ -2393,7 +2278,7 @@ def main(args):
|
|
2393 |
all_images_are_same = True
|
2394 |
all_masks_are_same = True
|
2395 |
all_guide_images_are_same = True
|
2396 |
-
for i, (
|
2397 |
prompts.append(prompt)
|
2398 |
negative_prompts.append(negative_prompt)
|
2399 |
seeds.append(seed)
|
@@ -2410,13 +2295,9 @@ def main(args):
|
|
2410 |
all_masks_are_same = mask_images[-2] is mask_image
|
2411 |
|
2412 |
if guide_image is not None:
|
2413 |
-
|
2414 |
-
|
2415 |
-
all_guide_images_are_same =
|
2416 |
-
else:
|
2417 |
-
guide_images.append(guide_image)
|
2418 |
-
if i > 0 and all_guide_images_are_same:
|
2419 |
-
all_guide_images_are_same = guide_images[-2] is guide_image
|
2420 |
|
2421 |
# make start code
|
2422 |
torch.manual_seed(seed)
|
@@ -2439,24 +2320,10 @@ def main(args):
|
|
2439 |
if guide_images is not None and all_guide_images_are_same:
|
2440 |
guide_images = guide_images[0]
|
2441 |
|
2442 |
-
# ControlNet使用時はguide imageをリサイズする
|
2443 |
-
if control_nets:
|
2444 |
-
# TODO resampleのメソッド
|
2445 |
-
guide_images = guide_images if type(guide_images) == list else [guide_images]
|
2446 |
-
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
|
2447 |
-
if len(guide_images) == 1:
|
2448 |
-
guide_images = guide_images[0]
|
2449 |
-
|
2450 |
# generate
|
2451 |
-
if networks:
|
2452 |
-
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
2453 |
-
n.set_multiplier(m)
|
2454 |
-
|
2455 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2456 |
-
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
2457 |
-
|
2458 |
-
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
2459 |
-
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
2460 |
return images
|
2461 |
|
2462 |
# save image
|
@@ -2531,7 +2398,6 @@ def main(args):
|
|
2531 |
strength = 0.8 if args.strength is None else args.strength
|
2532 |
negative_prompt = ""
|
2533 |
clip_prompt = None
|
2534 |
-
network_muls = None
|
2535 |
|
2536 |
prompt_args = prompt.strip().split(' --')
|
2537 |
prompt = prompt_args[0]
|
@@ -2595,15 +2461,6 @@ def main(args):
|
|
2595 |
clip_prompt = m.group(1)
|
2596 |
print(f"clip prompt: {clip_prompt}")
|
2597 |
continue
|
2598 |
-
|
2599 |
-
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
|
2600 |
-
if m: # network multiplies
|
2601 |
-
network_muls = [float(v) for v in m.group(1).split(",")]
|
2602 |
-
while len(network_muls) < len(networks):
|
2603 |
-
network_muls.append(network_muls[-1])
|
2604 |
-
print(f"network mul: {network_muls}")
|
2605 |
-
continue
|
2606 |
-
|
2607 |
except ValueError as ex:
|
2608 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2609 |
print(ex)
|
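The hunk above removes the per-prompt `am` option; the dropped lines parsed it into one multiplier per additional network, repeating the last value to pad the list. A condensed sketch of that removed parsing (the function name is ours):

import re

def parse_network_muls(parg, num_networks):
    # condensed from the removed "am" handling above
    m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
    if m is None:
        return None
    muls = [float(v) for v in m.group(1).split(",")]
    while len(muls) < num_networks:
        muls.append(muls[-1])  # repeat the last multiplier for the remaining networks
    return muls

# parse_network_muls("am 0.8,0.5", 3) -> [0.8, 0.5, 0.5]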
@@ -2641,12 +2498,7 @@ def main(args):
|
|
2641 |
mask_image = mask_images[global_step % len(mask_images)]
|
2642 |
|
2643 |
if guide_images is not None:
|
2644 |
-
|
2645 |
-
c = len(control_nets)
|
2646 |
-
p = global_step % (len(guide_images) // c)
|
2647 |
-
guide_image = guide_images[p * c:p * c + c]
|
2648 |
-
else:
|
2649 |
-
guide_image = guide_images[global_step % len(guide_images)]
|
2650 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2651 |
if prev_image is None:
|
2652 |
print("Generate 1st image without guide image.")
|
@@ -2654,9 +2506,10 @@ def main(args):
|
|
2654 |
print("Use previous image as guide image.")
|
2655 |
guide_image = prev_image
|
2656 |
|
2657 |
-
|
2658 |
-
|
2659 |
-
|
|
|
2660 |
process_batch(batch_data, highres_fix)
|
2661 |
batch_data.clear()
|
2662 |
|
@@ -2700,8 +2553,6 @@ if __name__ == '__main__':
|
|
2700 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2701 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2702 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
2703 |
-
parser.add_argument("--vae_batch_size", type=float, default=None,
|
2704 |
-
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
|
2705 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2706 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2707 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
@@ -2713,8 +2564,6 @@ if __name__ == '__main__':
|
|
2713 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2714 |
parser.add_argument("--vae", type=str, default=None,
|
2715 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
2716 |
-
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
2717 |
-
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
2718 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2719 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2720 |
parser.add_argument("--seed", type=int, default=None,
|
@@ -2729,15 +2578,12 @@ if __name__ == '__main__':
|
|
2729 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2730 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2731 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2732 |
-
help='
|
2733 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2734 |
-
help='
|
2735 |
-
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
2736 |
-
help='additional network multiplier / 追加ネットワークの効果の倍率')
|
2737 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2738 |
help='additional arguments for network (key=value) / ネットワークへの追加の引数')
|
2739 |
-
parser.add_argument("--network_show_meta", action='store_true',
|
2740 |
-
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
|
2741 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2742 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2743 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
@@ -2751,26 +2597,15 @@ if __name__ == '__main__':
|
|
2751 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2752 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2753 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2754 |
-
parser.add_argument("--guide_image_path", type=str, default=None,
|
2755 |
-
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
2756 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2757 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2758 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2759 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2760 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2761 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
2762 |
-
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
|
2763 |
-
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
|
2764 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2765 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2766 |
|
2767 |
-
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
|
2768 |
-
help='ControlNet models to use / 使用するControlNetのモデル名')
|
2769 |
-
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
|
2770 |
-
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
|
2771 |
-
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
|
2772 |
-
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
|
2773 |
-
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
|
2774 |
-
|
2775 |
args = parser.parse_args()
|
2776 |
main(args)
|
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
+
from typing import List, Optional, Union
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
63 |
+
from typing import Any, Callable, List, Optional, Union
|
64 |
|
65 |
import diffusers
|
66 |
import numpy as np
|
|
|
81 |
from PIL.PngImagePlugin import PngInfo
|
82 |
|
83 |
import library.model_util as model_util
|
|
|
|
|
|
|
84 |
|
85 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
86 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
|
487 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
488 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
489 |
|
|
|
|
|
|
|
490 |
# Textual Inversion
|
491 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
492 |
self.token_replacements[target_token_id] = rep_token_ids
|
|
|
500 |
new_tokens.append(token)
|
501 |
return new_tokens
|
502 |
|
|
|
|
|
|
|
503 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
|
|
504 |
def enable_xformers_memory_efficient_attention(self):
|
505 |
r"""
|
506 |
Enable memory efficient attention as implemented in xformers.
|
|
|
581 |
latents: Optional[torch.FloatTensor] = None,
|
582 |
max_embeddings_multiples: Optional[int] = 3,
|
583 |
output_type: Optional[str] = "pil",
|
|
|
|
|
584 |
# return_dict: bool = True,
|
585 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
586 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
|
|
672 |
else:
|
673 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
674 |
|
|
|
|
|
|
|
675 |
if strength < 0 or strength > 1:
|
676 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
677 |
|
|
|
752 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
753 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
754 |
|
755 |
+
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
|
756 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
757 |
clip_guide_images = [clip_guide_images]
|
758 |
|
|
|
765 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
766 |
if len(image_embeddings_clip) == 1:
|
767 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
768 |
+
else:
|
769 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
770 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
771 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
|
|
774 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
775 |
if len(image_embeddings_vgg16) == 1:
|
776 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
|
|
|
|
|
|
|
|
777 |
|
778 |
# set timesteps
|
779 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
|
|
781 |
latents_dtype = text_embeddings.dtype
|
782 |
init_latents_orig = None
|
783 |
mask = None
|
784 |
+
noise = None
|
785 |
|
786 |
if init_image is None:
|
787 |
# get the initial random noise unless the user supplied it
|
|
|
813 |
if isinstance(init_image[0], PIL.Image.Image):
|
814 |
init_image = [preprocess_image(im) for im in init_image]
|
815 |
init_image = torch.cat(init_image)
|
|
|
|
|
816 |
|
817 |
# mask image to tensor
|
818 |
if mask_image is not None:
|
|
|
823 |
|
824 |
# encode the init image into latents and scale the latents
|
825 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
826 |
+
init_latent_dist = self.vae.encode(init_image).latent_dist
|
827 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
828 |
+
init_latents = 0.18215 * init_latents
|
|
|
|
|
|
|
|
|
|
|
829 |
if len(init_latents) == 1:
|
830 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
831 |
init_latents_orig = init_latents
|
|
|
864 |
extra_step_kwargs["eta"] = eta
|
865 |
|
866 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
|
|
|
|
|
|
|
|
867 |
for i, t in enumerate(tqdm(timesteps)):
|
868 |
# expand the latents if we are doing classifier free guidance
|
869 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
870 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
871 |
# predict the noise residual
|
872 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
|
|
|
|
|
|
|
|
873 |
|
874 |
# perform guidance
|
875 |
if do_classifier_free_guidance:
|
|
|
911 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
912 |
return None
|
913 |
|
|
|
|
|
|
|
914 |
latents = 1 / 0.18215 * latents
|
915 |
+
image = self.vae.decode(latents).sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
916 |
|
917 |
image = (image / 2 + 0.5).clamp(0, 1)
|
918 |
|
|
|
1799 |
mask = mask.convert("L")
|
1800 |
w, h = mask.size
|
1801 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1802 |
+
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
|
1803 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1804 |
mask = np.tile(mask, (4, 1, 1))
|
1805 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
|
|
1817 |
# return text_encoder
|
1818 |
|
1819 |
|
|
|
|
|
|
|
|
|
|
1820 |
def main(args):
|
1821 |
if args.fp16:
|
1822 |
dtype = torch.float16
|
|
|
1881 |
# tokenizerを読み込む
|
1882 |
print("loading tokenizer")
|
1883 |
if use_stable_diffusion_format:
|
1884 |
+
if args.v2:
|
1885 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1886 |
+
else:
|
1887 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
1888 |
|
1889 |
# schedulerを用意する
|
1890 |
sched_init_args = {}
|
|
|
1995 |
# networkを組み込む
|
1996 |
if args.network_module:
|
1997 |
networks = []
|
|
|
1998 |
for i, network_module in enumerate(args.network_module):
|
1999 |
print("import network module:", network_module)
|
2000 |
imported_module = importlib.import_module(network_module)
|
2001 |
|
2002 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
|
|
2003 |
|
2004 |
net_kwargs = {}
|
2005 |
if args.network_args and i < len(args.network_args):
|
|
|
2014 |
network_weight = args.network_weights[i]
|
2015 |
print("load network weights from:", network_weight)
|
2016 |
|
2017 |
+
if model_util.is_safetensors(network_weight):
|
2018 |
from safetensors.torch import safe_open
|
2019 |
with safe_open(network_weight, framework="pt") as f:
|
2020 |
metadata = f.metadata()
|
|
|
2037 |
else:
|
2038 |
networks = []
|
2039 |
|
|
|
|
|
|
|
|
2040 |
if args.opt_channels_last:
|
2041 |
print(f"set optimizing: channels last")
|
2042 |
text_encoder.to(memory_format=torch.channels_last)
|
|
|
2050 |
if vgg16_model is not None:
|
2051 |
vgg16_model.to(memory_format=torch.channels_last)
|
2052 |
|
|
|
|
|
|
|
|
|
2053 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2054 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2055 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
|
|
2056 |
print("pipeline is ready.")
|
2057 |
|
2058 |
if args.diffusers_xformers:
|
|
|
2186 |
|
2187 |
prev_image = None # for VGG16 guided
|
2188 |
if args.guide_image_path is not None:
|
2189 |
+
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
|
2190 |
+
guide_images = load_images(args.guide_image_path)
|
2191 |
+
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
|
|
|
|
|
|
|
2192 |
if len(guide_images) == 0:
|
2193 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2194 |
guide_images = None
|
|
|
2219 |
iter_seed = random.randint(0, 0x7fffffff)
|
2220 |
|
2221 |
# バッチ処理の関数
|
2222 |
+
def process_batch(batch, highres_fix, highres_1st=False):
|
2223 |
batch_size = len(batch)
|
2224 |
|
2225 |
# highres_fixの処理
|
2226 |
if highres_fix and not highres_1st:
|
2227 |
+
# 1st stageのバッチを作成して呼び出す
|
2228 |
+
print("process 1st stage1")
|
2229 |
batch_1st = []
|
2230 |
+
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
2231 |
+
width_1st = int(width * args.highres_fix_scale + .5)
|
2232 |
+
height_1st = int(height * args.highres_fix_scale + .5)
|
2233 |
width_1st = width_1st - width_1st % 32
|
2234 |
height_1st = height_1st - height_1st % 32
|
2235 |
+
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
|
|
|
|
|
|
2236 |
images_1st = process_batch(batch_1st, True, True)
|
2237 |
|
2238 |
# 2nd stageのバッチを作成して以下処理する
|
2239 |
+
print("process 2nd stage1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2240 |
batch_2nd = []
|
2241 |
+
for i, (b1, image) in enumerate(zip(batch, images_1st)):
|
2242 |
+
image = image.resize((width, height), resample=PIL.Image.LANCZOS)
|
2243 |
+
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
|
2244 |
+
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
|
|
2245 |
batch = batch_2nd
|
2246 |
|
2247 |
+
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
2248 |
+
height, steps, scale, negative_scale, strength) = batch[0]
|
|
|
2249 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2250 |
|
2251 |
prompts = []
|
|
|
2278 |
all_images_are_same = True
|
2279 |
all_masks_are_same = True
|
2280 |
all_guide_images_are_same = True
|
2281 |
+
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
2282 |
prompts.append(prompt)
|
2283 |
negative_prompts.append(negative_prompt)
|
2284 |
seeds.append(seed)
|
|
|
2295 |
all_masks_are_same = mask_images[-2] is mask_image
|
2296 |
|
2297 |
if guide_image is not None:
|
2298 |
+
guide_images.append(guide_image)
|
2299 |
+
if i > 0 and all_guide_images_are_same:
|
2300 |
+
all_guide_images_are_same = guide_images[-2] is guide_image
|
|
|
|
|
|
|
|
|
2301 |
|
2302 |
# make start code
|
2303 |
torch.manual_seed(seed)
|
|
|
2320 |
if guide_images is not None and all_guide_images_are_same:
|
2321 |
guide_images = guide_images[0]
|
2322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2323 |
# generate
|
|
|
|
|
|
|
|
|
2324 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2325 |
+
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
2326 |
+
if highres_1st and not args.highres_fix_save_1st:
|
|
|
|
|
2327 |
return images
|
2328 |
|
2329 |
# save image
|
|
|
2398 |
strength = 0.8 if args.strength is None else args.strength
|
2399 |
negative_prompt = ""
|
2400 |
clip_prompt = None
|
|
|
2401 |
|
2402 |
prompt_args = prompt.strip().split(' --')
|
2403 |
prompt = prompt_args[0]
|
|
|
2461 |
clip_prompt = m.group(1)
|
2462 |
print(f"clip prompt: {clip_prompt}")
|
2463 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2464 |
except ValueError as ex:
|
2465 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2466 |
print(ex)
|
|
|
2498 |
mask_image = mask_images[global_step % len(mask_images)]
|
2499 |
|
2500 |
if guide_images is not None:
|
2501 |
+
guide_image = guide_images[global_step % len(guide_images)]
|
|
|
|
|
|
|
|
|
|
|
2502 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2503 |
if prev_image is None:
|
2504 |
print("Generate 1st image without guide image.")
|
|
|
2506 |
print("Use previous image as guide image.")
|
2507 |
guide_image = prev_image
|
2508 |
|
2509 |
+
# TODO named tupleか何かにする
|
2510 |
+
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
2511 |
+
(width, height, steps, scale, negative_scale, strength))
|
2512 |
+
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
2513 |
process_batch(batch_data, highres_fix)
|
2514 |
batch_data.clear()
|
2515 |
|
|
|
2553 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2554 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2555 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
|
|
|
|
2556 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2557 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2558 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
|
|
2564 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2565 |
parser.add_argument("--vae", type=str, default=None,
|
2566 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
|
|
|
|
2567 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2568 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2569 |
parser.add_argument("--seed", type=int, default=None,
|
|
|
2578 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2579 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2580 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2581 |
+
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
2582 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2583 |
+
help='Hypernetwork weights to load / Hypernetworkの重み')
|
2584 |
+
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
|
|
2585 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2586 |
help='additional arguments for network (key=value) / ネットワークへの追加の引数')
|
|
|
|
|
2587 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2588 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2589 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
|
|
2597 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2598 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2599 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2600 |
+
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
|
|
2601 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2602 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2603 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2604 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2605 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2606 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
|
|
|
|
2607 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2608 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2609 |
|
|
|
|
|
|
2610 |
args = parser.parse_args()
|
2611 |
main(args)
|
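The reverted process_batch passes each queued generation around as a bare nested tuple and leaves a `# TODO named tupleか何かにする` ("make this a named tuple or something") near line 2509. Purely as a sketch, the fields visible above could be named like this, loosely following the BatchData/BatchDataExt named tuples used by the removed code:

from typing import NamedTuple, Optional
from PIL import Image

class BatchDataBase(NamedTuple):
    # first element of each batch entry: per-prompt data
    step: int
    prompt: str
    negative_prompt: str
    seed: int
    init_image: Optional[Image.Image]
    mask_image: Optional[Image.Image]
    clip_prompt: Optional[str]
    guide_image: Optional[Image.Image]

class BatchDataExt(NamedTuple):
    # second element: generation settings shared by the whole batch
    width: int
    height: int
    steps: int
    scale: float
    negative_scale: Optional[float]
    strength: float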
library/train_util.py
CHANGED
@@ -1,21 +1,12 @@
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
4 |
-
import importlib
|
5 |
import json
|
6 |
-
import re
|
7 |
import shutil
|
8 |
import time
|
9 |
-
from typing import
|
10 |
-
Dict,
|
11 |
-
List,
|
12 |
-
NamedTuple,
|
13 |
-
Optional,
|
14 |
-
Sequence,
|
15 |
-
Tuple,
|
16 |
-
Union,
|
17 |
-
)
|
18 |
from accelerate import Accelerator
|
|
|
19 |
import glob
|
20 |
import math
|
21 |
import os
|
@@ -26,16 +17,10 @@ from io import BytesIO
|
|
26 |
|
27 |
from tqdm import tqdm
|
28 |
import torch
|
29 |
-
from torch.optim import Optimizer
|
30 |
from torchvision import transforms
|
31 |
from transformers import CLIPTokenizer
|
32 |
-
import transformers
|
33 |
import diffusers
|
34 |
-
from diffusers
|
35 |
-
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
36 |
-
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
37 |
-
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
38 |
-
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
39 |
import albumentations as albu
|
40 |
import numpy as np
|
41 |
from PIL import Image
|
@@ -210,93 +195,23 @@ class BucketBatchIndex(NamedTuple):
|
|
210 |
batch_index: int
|
211 |
|
212 |
|
213 |
-
class AugHelper:
|
214 |
-
def __init__(self):
|
215 |
-
# prepare all possible augmentators
|
216 |
-
color_aug_method = albu.OneOf([
|
217 |
-
albu.HueSaturationValue(8, 0, 0, p=.5),
|
218 |
-
albu.RandomGamma((95, 105), p=.5),
|
219 |
-
], p=.33)
|
220 |
-
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
221 |
-
|
222 |
-
# key: (use_color_aug, use_flip_aug)
|
223 |
-
self.augmentors = {
|
224 |
-
(True, True): albu.Compose([
|
225 |
-
color_aug_method,
|
226 |
-
flip_aug_method,
|
227 |
-
], p=1.),
|
228 |
-
(True, False): albu.Compose([
|
229 |
-
color_aug_method,
|
230 |
-
], p=1.),
|
231 |
-
(False, True): albu.Compose([
|
232 |
-
flip_aug_method,
|
233 |
-
], p=1.),
|
234 |
-
(False, False): None
|
235 |
-
}
|
236 |
-
|
237 |
-
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
238 |
-
return self.augmentors[(use_color_aug, use_flip_aug)]
|
239 |
-
|
240 |
-
|
241 |
-
class BaseSubset:
|
242 |
-
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
243 |
-
self.image_dir = image_dir
|
244 |
-
self.num_repeats = num_repeats
|
245 |
-
self.shuffle_caption = shuffle_caption
|
246 |
-
self.keep_tokens = keep_tokens
|
247 |
-
self.color_aug = color_aug
|
248 |
-
self.flip_aug = flip_aug
|
249 |
-
self.face_crop_aug_range = face_crop_aug_range
|
250 |
-
self.random_crop = random_crop
|
251 |
-
self.caption_dropout_rate = caption_dropout_rate
|
252 |
-
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
253 |
-
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
254 |
-
|
255 |
-
self.img_count = 0
|
256 |
-
|
257 |
-
|
258 |
-
class DreamBoothSubset(BaseSubset):
|
259 |
-
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
260 |
-
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
261 |
-
|
262 |
-
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
263 |
-
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
264 |
-
|
265 |
-
self.is_reg = is_reg
|
266 |
-
self.class_tokens = class_tokens
|
267 |
-
self.caption_extension = caption_extension
|
268 |
-
|
269 |
-
def __eq__(self, other) -> bool:
|
270 |
-
if not isinstance(other, DreamBoothSubset):
|
271 |
-
return NotImplemented
|
272 |
-
return self.image_dir == other.image_dir
|
273 |
-
|
274 |
-
class FineTuningSubset(BaseSubset):
|
275 |
-
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
276 |
-
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
277 |
-
|
278 |
-
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
279 |
-
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
280 |
-
|
281 |
-
self.metadata_file = metadata_file
|
282 |
-
|
283 |
-
def __eq__(self, other) -> bool:
|
284 |
-
if not isinstance(other, FineTuningSubset):
|
285 |
-
return NotImplemented
|
286 |
-
return self.metadata_file == other.metadata_file
|
287 |
-
|
288 |
class BaseDataset(torch.utils.data.Dataset):
|
289 |
-
def __init__(self, tokenizer
|
290 |
super().__init__()
|
291 |
-
self.tokenizer = tokenizer
|
292 |
self.max_token_length = max_token_length
|
|
|
|
|
293 |
# width/height is used when enable_bucket==False
|
294 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
295 |
self.debug_dataset = debug_dataset
|
296 |
-
|
297 |
-
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
298 |
-
|
299 |
self.token_padding_disabled = False
|
|
|
|
|
300 |
self.tag_frequency = {}
|
301 |
|
302 |
self.enable_bucket = False
|
@@ -310,28 +225,49 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
310 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
311 |
|
312 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
313 |
|
314 |
# augmentation
|
315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
318 |
|
319 |
self.image_data: Dict[str, ImageInfo] = {}
|
320 |
-
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
321 |
|
322 |
self.replacements = {}
|
323 |
|
324 |
def set_current_epoch(self, epoch):
|
325 |
self.current_epoch = epoch
|
326 |
-
|
|
|
|
|
|
|
|
|
|
|
327 |
|
328 |
def set_tag_frequency(self, dir_name, captions):
|
329 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
330 |
self.tag_frequency[dir_name] = frequency_for_dir
|
331 |
for caption in captions:
|
332 |
for tag in caption.split(","):
|
333 |
-
tag
|
334 |
-
if tag:
|
335 |
tag = tag.lower()
|
336 |
frequency = frequency_for_dir.get(tag, 0)
|
337 |
frequency_for_dir[tag] = frequency + 1
|
@@ -342,36 +278,42 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
342 |
def add_replacement(self, str_from, str_to):
|
343 |
self.replacements[str_from] = str_to
|
344 |
|
345 |
-
def process_caption(self,
|
346 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
347 |
-
is_drop_out =
|
348 |
-
is_drop_out = is_drop_out or
|
349 |
|
350 |
if is_drop_out:
|
351 |
caption = ""
|
352 |
else:
|
353 |
-
if
|
354 |
def dropout_tags(tokens):
|
355 |
-
if
|
356 |
return tokens
|
357 |
l = []
|
358 |
for token in tokens:
|
359 |
-
if random.random() >=
|
360 |
l.append(token)
|
361 |
return l
|
362 |
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
-
|
370 |
-
|
371 |
|
372 |
-
|
373 |
|
374 |
-
|
|
|
375 |
|
376 |
# textual inversion対応
|
377 |
for str_from, str_to in self.replacements.items():
|
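The caption-processing hunk above trims the dropout handling down; the tag-dropout step it keeps amounts to dropping each comma-separated tag independently. A minimal sketch of that idea (the rate parameter stands in for the caption_tag_dropout_rate setting):

import random

def dropout_tags(tokens, rate=0.0):
    # each token survives independently with probability 1 - rate
    if rate <= 0:
        return tokens
    return [token for token in tokens if random.random() >= rate]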
@@ -425,9 +367,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
425 |
input_ids = torch.stack(iids_list) # 3,77
|
426 |
return input_ids
|
427 |
|
428 |
-
def register_image(self, info: ImageInfo
|
429 |
self.image_data[info.image_key] = info
|
430 |
-
self.image_to_subset[info.image_key] = subset
|
431 |
|
432 |
def make_buckets(self):
|
433 |
'''
|
@@ -526,7 +467,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
526 |
img = np.array(image, np.uint8)
|
527 |
return img
|
528 |
|
529 |
-
def trim_and_resize_if_required(self,
|
530 |
image_height, image_width = image.shape[0:2]
|
531 |
|
532 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
@@ -536,27 +477,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
536 |
image_height, image_width = image.shape[0:2]
|
537 |
if image_width > reso[0]:
|
538 |
trim_size = image_width - reso[0]
|
539 |
-
p = trim_size // 2 if not
|
540 |
# print("w", trim_size, p)
|
541 |
image = image[:, p:p + reso[0]]
|
542 |
if image_height > reso[1]:
|
543 |
trim_size = image_height - reso[1]
|
544 |
-
p = trim_size // 2 if not
|
545 |
# print("h", trim_size, p)
|
546 |
image = image[p:p + reso[1]]
|
547 |
|
548 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
549 |
return image
|
550 |
|
551 |
-
def is_latent_cacheable(self):
|
552 |
-
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
553 |
-
|
554 |
def cache_latents(self, vae):
|
555 |
# TODO ここを高速化したい
|
556 |
print("caching latents.")
|
557 |
for info in tqdm(self.image_data.values()):
|
558 |
-
subset = self.image_to_subset[info.image_key]
|
559 |
-
|
560 |
if info.latents_npz is not None:
|
561 |
info.latents = self.load_latents_from_npz(info, False)
|
562 |
info.latents = torch.FloatTensor(info.latents)
|
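In the trim step earlier in this hunk, the crop offset is centered unless random cropping is requested; the exact condition is cut off above, so the flag name below is an assumption. A one-line sketch of the offset choice:

import random

def trim_offset(trim_size, random_crop):
    # centered crop by default; random offset within the spare pixels when random_crop is set (assumed)
    return random.randint(0, trim_size) if random_crop else trim_size // 2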
@@ -566,13 +502,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
566 |
continue
|
567 |
|
568 |
image = self.load_image(info.absolute_path)
|
569 |
-
image = self.trim_and_resize_if_required(
|
570 |
|
571 |
img_tensor = self.image_transforms(image)
|
572 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
573 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
574 |
|
575 |
-
if
|
576 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
577 |
img_tensor = self.image_transforms(image)
|
578 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
@@ -582,11 +518,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
582 |
image = Image.open(image_path)
|
583 |
return image.size
|
584 |
|
585 |
-
def load_image_with_face_info(self,
|
586 |
img = self.load_image(image_path)
|
587 |
|
588 |
face_cx = face_cy = face_w = face_h = 0
|
589 |
-
if
|
590 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
591 |
if len(tokens) >= 5:
|
592 |
face_cx = int(tokens[-4])
|
@@ -597,7 +533,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
597 |
return img, face_cx, face_cy, face_w, face_h
|
598 |
|
599 |
# いい感じに切り出す
|
600 |
-
def crop_target(self,
|
601 |
height, width = image.shape[0:2]
|
602 |
if height == self.height and width == self.width:
|
603 |
return image
|
@@ -605,8 +541,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
605 |
# 画像サイズはsizeより大きいのでリサイズする
|
606 |
face_size = max(face_w, face_h)
|
607 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
608 |
-
min_scale = min(1.0, max(min_scale, self.size / (face_size *
|
609 |
-
max_scale = min(1.0, max(min_scale, self.size / (face_size *
|
610 |
if min_scale >= max_scale: # range指定がmin==max
|
611 |
scale = min_scale
|
612 |
else:
|
@@ -624,13 +560,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
624 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
625 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
626 |
|
627 |
-
if
|
628 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
629 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
630 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
631 |
else:
|
632 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
633 |
-
if
|
634 |
if face_size > self.size // 10 and face_size >= 40:
|
635 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
636 |
|
@@ -653,6 +589,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
653 |
return self._length
|
654 |
|
655 |
def __getitem__(self, index):
|
|
|
|
|
|
|
656 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
657 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
658 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
@@ -665,29 +604,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
665 |
|
666 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
667 |
image_info = self.image_data[image_key]
|
668 |
-
subset = self.image_to_subset[image_key]
|
669 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
670 |
|
671 |
# image/latentsを処理する
|
672 |
if image_info.latents is not None:
|
673 |
-
latents = image_info.latents if not
|
674 |
image = None
|
675 |
elif image_info.latents_npz is not None:
|
676 |
-
latents = self.load_latents_from_npz(image_info,
|
677 |
latents = torch.FloatTensor(latents)
|
678 |
image = None
|
679 |
else:
|
680 |
# 画像を読み込み、必要ならcropする
|
681 |
-
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
682 |
im_h, im_w = img.shape[0:2]
|
683 |
|
684 |
if self.enable_bucket:
|
685 |
-
img = self.trim_and_resize_if_required(
|
686 |
else:
|
687 |
if face_cx > 0: # 顔位置情報あり
|
688 |
-
img = self.crop_target(
|
689 |
elif im_h > self.height or im_w > self.width:
|
690 |
-
assert
|
691 |
if im_h > self.height:
|
692 |
p = random.randint(0, im_h - self.height)
|
693 |
img = img[p:p + self.height]
|
@@ -699,9 +637,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
699 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
700 |
|
701 |
# augmentation
|
702 |
-
|
703 |
-
|
704 |
-
img = aug(image=img)['image']
|
705 |
|
706 |
latents = None
|
707 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
@@ -709,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
709 |
images.append(image)
|
710 |
latents_list.append(latents)
|
711 |
|
712 |
-
caption = self.process_caption(
|
713 |
captions.append(caption)
|
714 |
if not self.token_padding_disabled: # this option might be omitted in future
|
715 |
input_ids_list.append(self.get_input_ids(caption))
|
@@ -740,8 +677,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
740 |
|
741 |
|
742 |
class DreamBoothDataset(BaseDataset):
|
743 |
-
def __init__(self,
|
744 |
-
super().__init__(tokenizer, max_token_length,
|
|
|
745 |
|
746 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
747 |
|
@@ -764,7 +702,7 @@ class DreamBoothDataset(BaseDataset):
|
|
764 |
self.bucket_reso_steps = None # この情報は使われない
|
765 |
self.bucket_no_upscale = False
|
766 |
|
767 |
-
def read_caption(img_path
|
768 |
# captionの候補ファイル名を作る
|
769 |
base_name = os.path.splitext(img_path)[0]
|
770 |
base_name_face_det = base_name
|
@@ -787,171 +725,153 @@ class DreamBoothDataset(BaseDataset):
|
|
787 |
break
|
788 |
return caption
|
789 |
|
790 |
-
def load_dreambooth_dir(
|
791 |
-
if not os.path.isdir(
|
792 |
-
print(f"
|
793 |
-
return [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
794 |
|
795 |
-
|
796 |
-
|
|
|
797 |
|
798 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
799 |
captions = []
|
800 |
for img_path in img_paths:
|
801 |
-
cap_for_img = read_caption(img_path
|
802 |
-
if cap_for_img is None
|
803 |
-
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
804 |
-
captions.append("")
|
805 |
-
else:
|
806 |
-
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
807 |
-
|
808 |
-
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
809 |
|
810 |
-
|
811 |
-
|
812 |
-
print("prepare images.")
|
813 |
-
num_train_images = 0
|
814 |
-
num_reg_images = 0
|
815 |
-
reg_infos: List[ImageInfo] = []
|
816 |
-
for subset in subsets:
|
817 |
-
if subset.num_repeats < 1:
|
818 |
-
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
819 |
-
continue
|
820 |
-
|
821 |
-
if subset in self.subsets:
|
822 |
-
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
823 |
-
continue
|
824 |
|
825 |
-
img_paths, captions
|
826 |
-
if len(img_paths) < 1:
|
827 |
-
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
828 |
-
continue
|
829 |
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
|
|
|
|
834 |
|
835 |
for img_path, caption in zip(img_paths, captions):
|
836 |
-
info = ImageInfo(img_path,
|
837 |
-
|
838 |
-
reg_infos.append(info)
|
839 |
-
else:
|
840 |
-
self.register_image(info, subset)
|
841 |
|
842 |
-
|
843 |
-
self.subsets.append(subset)
|
844 |
|
845 |
print(f"{num_train_images} train images with repeating.")
|
846 |
self.num_train_images = num_train_images
|
847 |
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
print("no regularization images / 正則化画像が見つかりませんでした")
|
854 |
-
else:
|
855 |
-
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
856 |
-
n = 0
|
857 |
-
first_loop = True
|
858 |
-
while n < num_train_images:
|
859 |
-
for info in reg_infos:
|
860 |
-
if first_loop:
|
861 |
-
self.register_image(info, subset)
|
862 |
-
n += info.num_repeats
|
863 |
-
else:
|
864 |
-
info.num_repeats += 1
|
865 |
-
n += 1
|
866 |
-
if n >= num_train_images:
|
867 |
-
break
|
868 |
-
first_loop = False
|
869 |
-
|
870 |
-
self.num_reg_images = num_reg_images
|
871 |
-
|
872 |
-
|
873 |
-
class FineTuningDataset(BaseDataset):
|
874 |
-
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
875 |
-
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
876 |
|
877 |
-
|
|
|
|
|
|
|
878 |
|
879 |
-
|
880 |
-
|
|
|
881 |
|
882 |
-
|
883 |
-
if subset.num_repeats < 1:
|
884 |
-
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
885 |
-
continue
|
886 |
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
|
891 |
-
|
892 |
-
|
893 |
-
print(f"loading existing metadata: {subset.metadata_file}")
|
894 |
-
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
895 |
-
metadata = json.load(f)
|
896 |
else:
|
897 |
-
|
|
|
|
|
|
|
|
|
|
|
|
898 |
|
899 |
-
|
900 |
-
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
901 |
-
continue
|
902 |
-
|
903 |
-
tags_list = []
|
904 |
-
for image_key, img_md in metadata.items():
|
905 |
-
# path情報を作る
|
906 |
-
if os.path.exists(image_key):
|
907 |
-
abs_path = image_key
|
908 |
-
else:
|
909 |
-
# わりといい加減だがいい方法が思いつかん
|
910 |
-
abs_path = glob_images(subset.image_dir, image_key)
|
911 |
-
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
912 |
-
abs_path = abs_path[0]
|
913 |
-
|
914 |
-
caption = img_md.get('caption')
|
915 |
-
tags = img_md.get('tags')
|
916 |
-
if caption is None:
|
917 |
-
caption = tags
|
918 |
-
elif tags is not None and len(tags) > 0:
|
919 |
-
caption = caption + ', ' + tags
|
920 |
-
tags_list.append(tags)
|
921 |
-
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
922 |
|
923 |
-
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
924 |
-
image_info.image_size = img_md.get('train_resolution')
|
925 |
|
926 |
-
|
927 |
-
|
928 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
929 |
|
930 |
-
|
|
|
|
|
931 |
|
932 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
933 |
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
self.subsets.append(subset)
|
938 |
|
939 |
# check existence of all npz files
|
940 |
-
use_npz_latents =
|
941 |
if use_npz_latents:
|
942 |
-
flip_aug_in_subset = False
|
943 |
npz_any = False
|
944 |
npz_all = True
|
945 |
-
|
946 |
for image_info in self.image_data.values():
|
947 |
-
subset = self.image_to_subset[image_info.image_key]
|
948 |
-
|
949 |
has_npz = image_info.latents_npz is not None
|
950 |
npz_any = npz_any or has_npz
|
951 |
|
952 |
-
if
|
953 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
954 |
-
flip_aug_in_subset = True
|
955 |
npz_all = npz_all and has_npz
|
956 |
|
957 |
if npz_any and not npz_all:
|
@@ -963,7 +883,7 @@ class FineTuningDataset(BaseDataset):
|
|
963 |
elif not npz_all:
|
964 |
use_npz_latents = False
|
965 |
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
966 |
-
if
|
967 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
968 |
# else:
|
969 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
@@ -1009,7 +929,7 @@ class FineTuningDataset(BaseDataset):
|
|
1009 |
for image_info in self.image_data.values():
|
1010 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
1011 |
|
1012 |
-
def image_key_to_npz_file(self,
|
1013 |
base_name = os.path.splitext(image_key)[0]
|
1014 |
npz_file_norm = base_name + '.npz'
|
1015 |
|
@@ -1021,8 +941,8 @@ class FineTuningDataset(BaseDataset):
|
|
1021 |
return npz_file_norm, npz_file_flip
|
1022 |
|
1023 |
# image_key is relative path
|
1024 |
-
npz_file_norm = os.path.join(
|
1025 |
-
npz_file_flip = os.path.join(
|
1026 |
|
1027 |
if not os.path.exists(npz_file_norm):
|
1028 |
npz_file_norm = None
|
@@ -1033,60 +953,13 @@ class FineTuningDataset(BaseDataset):
|
|
1033 |
return npz_file_norm, npz_file_flip
|
1034 |
|
1035 |
|
1036 |
-
# behave as Dataset mock
|
1037 |
-
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1038 |
-
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
1039 |
-
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
1040 |
-
|
1041 |
-
super().__init__(datasets)
|
1042 |
-
|
1043 |
-
self.image_data = {}
|
1044 |
-
self.num_train_images = 0
|
1045 |
-
self.num_reg_images = 0
|
1046 |
-
|
1047 |
-
# simply concat together
|
1048 |
-
# TODO: handling image_data key duplication among dataset
|
1049 |
-
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
1050 |
-
for dataset in datasets:
|
1051 |
-
self.image_data.update(dataset.image_data)
|
1052 |
-
self.num_train_images += dataset.num_train_images
|
1053 |
-
self.num_reg_images += dataset.num_reg_images
|
1054 |
-
|
1055 |
-
def add_replacement(self, str_from, str_to):
|
1056 |
-
for dataset in self.datasets:
|
1057 |
-
dataset.add_replacement(str_from, str_to)
|
1058 |
-
|
1059 |
-
# def make_buckets(self):
|
1060 |
-
# for dataset in self.datasets:
|
1061 |
-
# dataset.make_buckets()
|
1062 |
-
|
1063 |
-
def cache_latents(self, vae):
|
1064 |
-
for i, dataset in enumerate(self.datasets):
|
1065 |
-
print(f"[Dataset {i}]")
|
1066 |
-
dataset.cache_latents(vae)
|
1067 |
-
|
1068 |
-
def is_latent_cacheable(self) -> bool:
|
1069 |
-
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
1070 |
-
|
1071 |
-
def set_current_epoch(self, epoch):
|
1072 |
-
for dataset in self.datasets:
|
1073 |
-
dataset.set_current_epoch(epoch)
|
1074 |
-
|
1075 |
-
def disable_token_padding(self):
|
1076 |
-
for dataset in self.datasets:
|
1077 |
-
dataset.disable_token_padding()
|
1078 |
-
|
1079 |
-
|
1080 |
def debug_dataset(train_dataset, show_input_ids=False):
|
1081 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
1082 |
print("Escape for exit. / Escキーで中断、終了します")
|
1083 |
|
1084 |
train_dataset.set_current_epoch(1)
|
1085 |
k = 0
|
1086 |
-
|
1087 |
-
random.shuffle(indices)
|
1088 |
-
for i, idx in enumerate(indices):
|
1089 |
-
example = train_dataset[idx]
|
1090 |
if example['latents'] is not None:
|
1091 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
1092 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
@@ -1491,35 +1364,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|
1491 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1492 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1493 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1494 |
-
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
1495 |
-
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
1496 |
-
|
1497 |
-
|
1498 |
-
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
1499 |
-
parser.add_argument("--optimizer_type", type=str, default="",
|
1500 |
-
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
|
1501 |
-
|
1502 |
-
# backward compatibility
|
1503 |
-
parser.add_argument("--use_8bit_adam", action="store_true",
|
1504 |
-
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1505 |
-
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1506 |
-
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1507 |
-
|
1508 |
-
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1509 |
-
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
1510 |
-
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
|
1511 |
-
|
1512 |
-
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
|
1513 |
-
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
1514 |
-
|
1515 |
-
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1516 |
-
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
|
1517 |
-
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1518 |
-
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1519 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
1520 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
1521 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
1522 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
1523 |
|
1524 |
|
1525 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
@@ -1543,6 +1387,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1543 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1544 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1545 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
|
|
|
|
|
|
|
|
1546 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1547 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1548 |
parser.add_argument("--xformers", action="store_true",
|
@@ -1550,6 +1398,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1550 |
parser.add_argument("--vae", type=str, default=None,
|
1551 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1552 |
|
|
|
1553 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1554 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1555 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
@@ -1570,23 +1419,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1570 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1571 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1572 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
|
|
|
|
|
|
|
|
1573 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1574 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1575 |
parser.add_argument("--lowram", action="store_true",
|
1576 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1577 |
|
1578 |
-
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
1579 |
-
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
1580 |
-
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
1581 |
-
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
1582 |
-
parser.add_argument("--sample_prompts", type=str, default=None,
|
1583 |
-
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
1584 |
-
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
1585 |
-
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
1586 |
-
'dpmsolver++', 'dpmsingle',
|
1587 |
-
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
1588 |
-
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
1589 |
-
|
1590 |
if support_dreambooth:
|
1591 |
# DreamBooth training
|
1592 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
@@ -1608,8 +1449,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1608 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1609 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1610 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1611 |
-
parser.add_argument("--keep_tokens", type=int, default=
|
1612 |
-
help="keep heading N tokens when shuffling caption tokens
|
1613 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1614 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1615 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
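The --keep_tokens help above says the leading N tokens are kept when caption tokens are shuffled; a small illustration of that behavior under --shuffle_caption (assumed, for illustration only):

import random

def shuffle_caption(caption, keep_tokens):
    # keep the first N comma-separated tokens in place, shuffle the rest (assumed behavior)
    tokens = [t.strip() for t in caption.split(",")]
    head, tail = tokens[:keep_tokens], tokens[keep_tokens:]
    random.shuffle(tail)
    return ", ".join(head + tail)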
@@ -1634,11 +1475,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1634 |
if support_caption_dropout:
|
1635 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1636 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1637 |
-
parser.add_argument("--caption_dropout_rate", type=float, default=0
|
1638 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1639 |
-
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
|
1640 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1641 |
-
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0
|
1642 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1643 |
|
1644 |
if support_dreambooth:
|
@@ -1663,249 +1504,16 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|
1663 |
# region utils
|
1664 |
|
1665 |
|
1666 |
-
def get_optimizer(args, trainable_params):
|
1667 |
-
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
1668 |
-
|
1669 |
-
optimizer_type = args.optimizer_type
|
1670 |
-
if args.use_8bit_adam:
|
1671 |
-
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
|
1672 |
-
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
|
1673 |
-
optimizer_type = "AdamW8bit"
|
1674 |
-
|
1675 |
-
elif args.use_lion_optimizer:
|
1676 |
-
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
|
1677 |
-
optimizer_type = "Lion"
|
1678 |
-
|
1679 |
-
if optimizer_type is None or optimizer_type == "":
|
1680 |
-
optimizer_type = "AdamW"
|
1681 |
-
optimizer_type = optimizer_type.lower()
|
1682 |
-
|
1683 |
-
# 引数を分解する:boolとfloat、tupleのみ対応
|
1684 |
-
optimizer_kwargs = {}
|
1685 |
-
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
1686 |
-
for arg in args.optimizer_args:
|
1687 |
-
key, value = arg.split('=')
|
1688 |
-
|
1689 |
-
value = value.split(",")
|
1690 |
-
for i in range(len(value)):
|
1691 |
-
if value[i].lower() == "true" or value[i].lower() == "false":
|
1692 |
-
value[i] = (value[i].lower() == "true")
|
1693 |
-
else:
|
1694 |
-
value[i] = float(value[i])
|
1695 |
-
if len(value) == 1:
|
1696 |
-
value = value[0]
|
1697 |
-
else:
|
1698 |
-
value = tuple(value)
|
1699 |
-
|
1700 |
-
optimizer_kwargs[key] = value
|
1701 |
-
# print("optkwargs:", optimizer_kwargs)
|
1702 |
-
|
1703 |
-
lr = args.learning_rate
|
1704 |
-
|
1705 |
-
if optimizer_type == "AdamW8bit".lower():
|
1706 |
-
try:
|
1707 |
-
import bitsandbytes as bnb
|
1708 |
-
except ImportError:
|
1709 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1710 |
-
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
1711 |
-
optimizer_class = bnb.optim.AdamW8bit
|
1712 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1713 |
-
|
1714 |
-
elif optimizer_type == "SGDNesterov8bit".lower():
|
1715 |
-
try:
|
1716 |
-
import bitsandbytes as bnb
|
1717 |
-
except ImportError:
|
1718 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1719 |
-
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1720 |
-
if "momentum" not in optimizer_kwargs:
|
1721 |
-
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1722 |
-
optimizer_kwargs["momentum"] = 0.9
|
1723 |
-
|
1724 |
-
optimizer_class = bnb.optim.SGD8bit
|
1725 |
-
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1726 |
-
|
1727 |
-
elif optimizer_type == "Lion".lower():
|
1728 |
-
try:
|
1729 |
-
import lion_pytorch
|
1730 |
-
except ImportError:
|
1731 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
1732 |
-
print(f"use Lion optimizer | {optimizer_kwargs}")
|
1733 |
-
optimizer_class = lion_pytorch.Lion
|
1734 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1735 |
-
|
1736 |
-
elif optimizer_type == "SGDNesterov".lower():
|
1737 |
-
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1738 |
-
if "momentum" not in optimizer_kwargs:
|
1739 |
-
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1740 |
-
optimizer_kwargs["momentum"] = 0.9
|
1741 |
-
|
1742 |
-
optimizer_class = torch.optim.SGD
|
1743 |
-
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1744 |
-
|
1745 |
-
elif optimizer_type == "DAdaptation".lower():
|
1746 |
-
try:
|
1747 |
-
import dadaptation
|
1748 |
-
except ImportError:
|
1749 |
-
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
1750 |
-
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
1751 |
-
|
1752 |
-
min_lr = lr
|
1753 |
-
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1754 |
-
for group in trainable_params:
|
1755 |
-
min_lr = min(min_lr, group.get("lr", lr))
|
1756 |
-
|
1757 |
-
if min_lr <= 0.1:
|
1758 |
-
print(
|
1759 |
-
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
|
1760 |
-
print('recommend option: lr=1.0 / 推奨は1.0です')
|
1761 |
-
|
1762 |
-
optimizer_class = dadaptation.DAdaptAdam
|
1763 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1764 |
-
|
1765 |
-
elif optimizer_type == "Adafactor".lower():
|
1766 |
-
# 引数を確認して適宜補正する
|
1767 |
-
if "relative_step" not in optimizer_kwargs:
|
1768 |
-
optimizer_kwargs["relative_step"] = True # default
|
1769 |
-
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
1770 |
-
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
|
1771 |
-
optimizer_kwargs["relative_step"] = True
|
1772 |
-
print(f"use Adafactor optimizer | {optimizer_kwargs}")
|
1773 |
-
|
1774 |
-
if optimizer_kwargs["relative_step"]:
|
1775 |
-
print(f"relative_step is true / relative_stepがtrueです")
|
1776 |
-
if lr != 0.0:
|
1777 |
-
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
1778 |
-
args.learning_rate = None
|
1779 |
-
|
1780 |
-
# trainable_paramsがgroupだった時の処理:lrを削除する
|
1781 |
-
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1782 |
-
has_group_lr = False
|
1783 |
-
for group in trainable_params:
|
1784 |
-
p = group.pop("lr", None)
|
1785 |
-
has_group_lr = has_group_lr or (p is not None)
|
1786 |
-
|
1787 |
-
if has_group_lr:
|
1788 |
-
# 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
|
1789 |
-
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
|
1790 |
-
args.unet_lr = None
|
1791 |
-
args.text_encoder_lr = None
|
1792 |
-
|
1793 |
-
if args.lr_scheduler != "adafactor":
|
1794 |
-
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
1795 |
-
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
1796 |
-
|
1797 |
-
lr = None
|
1798 |
-
else:
|
1799 |
-
if args.max_grad_norm != 0.0:
|
1800 |
-
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
|
1801 |
-
if args.lr_scheduler != "constant_with_warmup":
|
1802 |
-
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
1803 |
-
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
1804 |
-
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
1805 |
-
|
1806 |
-
optimizer_class = transformers.optimization.Adafactor
|
1807 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1808 |
-
|
1809 |
-
elif optimizer_type == "AdamW".lower():
|
1810 |
-
print(f"use AdamW optimizer | {optimizer_kwargs}")
|
1811 |
-
optimizer_class = torch.optim.AdamW
|
1812 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1813 |
-
|
1814 |
-
else:
|
1815 |
-
# 任意のoptimizerを使う
|
1816 |
-
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
1817 |
-
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
1818 |
-
if "." not in optimizer_type:
|
1819 |
-
optimizer_module = torch.optim
|
1820 |
-
else:
|
1821 |
-
values = optimizer_type.split(".")
|
1822 |
-
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
1823 |
-
optimizer_type = values[-1]
|
1824 |
-
|
1825 |
-
optimizer_class = getattr(optimizer_module, optimizer_type)
|
1826 |
-
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1827 |
-
|
1828 |
-
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
1829 |
-
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
1830 |
-
|
1831 |
-
return optimizer_name, optimizer_args, optimizer
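The `--optimizer_args` loop above only understands bool, float and tuple values: each entry is a `key=value` pair, and comma-separated values become a tuple. A minimal sketch of that conversion, with made-up example values rather than anything from this commit:

# Illustrative only: mirrors the key=value parsing in get_optimizer above.
def parse_optimizer_args(raw_args):
    kwargs = {}
    for arg in raw_args:
        key, value = arg.split('=')
        parts = [v.lower() == 'true' if v.lower() in ('true', 'false') else float(v)
                 for v in value.split(',')]
        kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
    return kwargs

# e.g. parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999"])
#   -> {'weight_decay': 0.01, 'betas': (0.9, 0.999)}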
|
1832 |
-
|
1833 |
-
|
1834 |
-
# Monkeypatch newer get_scheduler() function overriding the current version of diffusers.optimizer.get_scheduler
|
1835 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
1836 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
1837 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
1838 |
-
|
1839 |
-
|
1840 |
-
def get_scheduler_fix(
|
1841 |
-
name: Union[str, SchedulerType],
|
1842 |
-
optimizer: Optimizer,
|
1843 |
-
num_warmup_steps: Optional[int] = None,
|
1844 |
-
num_training_steps: Optional[int] = None,
|
1845 |
-
num_cycles: int = 1,
|
1846 |
-
power: float = 1.0,
|
1847 |
-
):
|
1848 |
-
"""
|
1849 |
-
Unified API to get any scheduler from its name.
|
1850 |
-
Args:
|
1851 |
-
name (`str` or `SchedulerType`):
|
1852 |
-
The name of the scheduler to use.
|
1853 |
-
optimizer (`torch.optim.Optimizer`):
|
1854 |
-
The optimizer that will be used during training.
|
1855 |
-
num_warmup_steps (`int`, *optional*):
|
1856 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
1857 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1858 |
-
num_training_steps (`int``, *optional*):
|
1859 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
1860 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1861 |
-
num_cycles (`int`, *optional*):
|
1862 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
1863 |
-
power (`float`, *optional*, defaults to 1.0):
|
1864 |
-
Power factor. See `POLYNOMIAL` scheduler
|
1865 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
1866 |
-
The index of the last epoch when resuming training.
|
1867 |
-
"""
|
1868 |
-
if name.startswith("adafactor"):
|
1869 |
-
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
1870 |
-
initial_lr = float(name.split(':')[1])
|
1871 |
-
# print("adafactor scheduler init lr", initial_lr)
|
1872 |
-
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
1873 |
-
|
1874 |
-
name = SchedulerType(name)
|
1875 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
1876 |
-
if name == SchedulerType.CONSTANT:
|
1877 |
-
return schedule_func(optimizer)
|
1878 |
-
|
1879 |
-
# All other schedulers require `num_warmup_steps`
|
1880 |
-
if num_warmup_steps is None:
|
1881 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
1882 |
-
|
1883 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
1884 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
1885 |
-
|
1886 |
-
# All other schedulers require `num_training_steps`
|
1887 |
-
if num_training_steps is None:
|
1888 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
1889 |
-
|
1890 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
1891 |
-
return schedule_func(
|
1892 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
1893 |
-
)
|
1894 |
-
|
1895 |
-
if name == SchedulerType.POLYNOMIAL:
|
1896 |
-
return schedule_func(
|
1897 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
1898 |
-
)
|
1899 |
-
|
1900 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
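Note the special `adafactor:<initial_lr>` scheduler name handled at the top of this function: `get_optimizer` rewrites `args.lr_scheduler` to that form when Adafactor runs with `relative_step=True`, and the initial learning rate is parsed back out of the string here. A small usage sketch, assuming this module's `get_scheduler_fix` is importable and using placeholder model/learning-rate values:

# Sketch only: the scheduler name carries the initial LR after the colon.
import torch
import transformers

model = torch.nn.Linear(4, 4)                              # placeholder parameters
optimizer = transformers.optimization.Adafactor(
    model.parameters(), lr=None, relative_step=True, warmup_init=True)
scheduler = get_scheduler_fix("adafactor:1.0", optimizer)  # -> AdafactorSchedule with initial_lr=1.0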
|
1901 |
-
|
1902 |
-
|
1903 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1904 |
# backward compatibility
|
1905 |
if args.caption_extention is not None:
|
1906 |
args.caption_extension = args.caption_extention
|
1907 |
args.caption_extention = None
|
1908 |
|
|
|
|
|
|
|
|
|
1909 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1910 |
if args.resolution is not None:
|
1911 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
@@ -1928,28 +1536,12 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|
1928 |
|
1929 |
def load_tokenizer(args: argparse.Namespace):
|
1930 |
print("prepare tokenizer")
|
1931 |
-
|
1932 |
-
|
1933 |
-
|
1934 |
-
|
1935 |
-
|
1936 |
-
if os.path.exists(local_tokenizer_path):
|
1937 |
-
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
1938 |
-
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
1939 |
-
|
1940 |
-
if tokenizer is None:
|
1941 |
-
if args.v2:
|
1942 |
-
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
1943 |
-
else:
|
1944 |
-
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
1945 |
-
|
1946 |
-
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
1947 |
print(f"update token length: {args.max_token_length}")
|
1948 |
-
|
1949 |
-
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
1950 |
-
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
1951 |
-
tokenizer.save_pretrained(local_tokenizer_path)
|
1952 |
-
|
1953 |
return tokenizer
|
1954 |
|
1955 |
|
@@ -2000,19 +1592,13 @@ def prepare_dtype(args: argparse.Namespace):
|
|
2000 |
|
2001 |
|
2002 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
2003 |
-
|
2004 |
-
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
2005 |
-
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
2006 |
if load_stable_diffusion_format:
|
2007 |
print("load StableDiffusion checkpoint")
|
2008 |
-
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
|
2009 |
else:
|
2010 |
print("load Diffusers pretrained models")
|
2011 |
-
|
2012 |
-
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
2013 |
-
except EnvironmentError as ex:
|
2014 |
-
print(
|
2015 |
-
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
2016 |
text_encoder = pipe.text_encoder
|
2017 |
vae = pipe.vae
|
2018 |
unet = pipe.unet
|
@@ -2181,185 +1767,6 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
2181 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
2182 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
2183 |
|
2184 |
-
|
2185 |
-
# scheduler:
|
2186 |
-
SCHEDULER_LINEAR_START = 0.00085
|
2187 |
-
SCHEDULER_LINEAR_END = 0.0120
|
2188 |
-
SCHEDULER_TIMESTEPS = 1000
|
2189 |
-
SCHEDLER_SCHEDULE = 'scaled_linear'
|
2190 |
-
|
2191 |
-
|
2192 |
-
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
2193 |
-
"""
|
2194 |
-
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
2195 |
-
clip skipは対応した
|
2196 |
-
"""
|
2197 |
-
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
2198 |
-
return
|
2199 |
-
if args.sample_every_n_epochs is not None:
|
2200 |
-
# sample_every_n_steps は無視する
|
2201 |
-
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
2202 |
-
return
|
2203 |
-
else:
|
2204 |
-
if steps % args.sample_every_n_steps != 0:
|
2205 |
-
return
|
2206 |
-
|
2207 |
-
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
2208 |
-
if not os.path.isfile(args.sample_prompts):
|
2209 |
-
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
2210 |
-
return
|
2211 |
-
|
2212 |
-
# ここでCUDAのキャッシュクリアとかしたほうがいいのか……
|
2213 |
-
|
2214 |
-
org_vae_device = vae.device # CPUにいるはず
|
2215 |
-
vae.to(device)
|
2216 |
-
|
2217 |
-
# clip skip 対応のための wrapper を作る
|
2218 |
-
if args.clip_skip is None:
|
2219 |
-
text_encoder_or_wrapper = text_encoder
|
2220 |
-
else:
|
2221 |
-
class Wrapper():
|
2222 |
-
def __init__(self, tenc) -> None:
|
2223 |
-
self.tenc = tenc
|
2224 |
-
self.config = {}
|
2225 |
-
super().__init__()
|
2226 |
-
|
2227 |
-
def __call__(self, input_ids, attention_mask):
|
2228 |
-
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
2229 |
-
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
2230 |
-
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
2231 |
-
pooled_output = enc_out['pooler_output']
|
2232 |
-
return encoder_hidden_states, pooled_output # 1st output is only used
|
2233 |
-
|
2234 |
-
text_encoder_or_wrapper = Wrapper(text_encoder)
|
2235 |
-
|
2236 |
-
# read prompts
|
2237 |
-
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
|
2238 |
-
prompts = f.readlines()
|
2239 |
-
|
2240 |
-
# schedulerを用意する
|
2241 |
-
sched_init_args = {}
|
2242 |
-
if args.sample_sampler == "ddim":
|
2243 |
-
scheduler_cls = DDIMScheduler
|
2244 |
-
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
2245 |
-
scheduler_cls = DDPMScheduler
|
2246 |
-
elif args.sample_sampler == "pndm":
|
2247 |
-
scheduler_cls = PNDMScheduler
|
2248 |
-
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
|
2249 |
-
scheduler_cls = LMSDiscreteScheduler
|
2250 |
-
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
|
2251 |
-
scheduler_cls = EulerDiscreteScheduler
|
2252 |
-
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
|
2253 |
-
scheduler_cls = EulerAncestralDiscreteScheduler
|
2254 |
-
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
2255 |
-
scheduler_cls = DPMSolverMultistepScheduler
|
2256 |
-
sched_init_args['algorithm_type'] = args.sample_sampler
|
2257 |
-
elif args.sample_sampler == "dpmsingle":
|
2258 |
-
scheduler_cls = DPMSolverSinglestepScheduler
|
2259 |
-
elif args.sample_sampler == "heun":
|
2260 |
-
scheduler_cls = HeunDiscreteScheduler
|
2261 |
-
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
|
2262 |
-
scheduler_cls = KDPM2DiscreteScheduler
|
2263 |
-
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
|
2264 |
-
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
2265 |
-
else:
|
2266 |
-
scheduler_cls = DDIMScheduler
|
2267 |
-
|
2268 |
-
if args.v_parameterization:
|
2269 |
-
sched_init_args['prediction_type'] = 'v_prediction'
|
2270 |
-
|
2271 |
-
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
|
2272 |
-
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
|
2273 |
-
beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
|
2274 |
-
|
2275 |
-
# clip_sample=Trueにする
|
2276 |
-
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
2277 |
-
# print("set clip_sample to True")
|
2278 |
-
scheduler.config.clip_sample = True
|
2279 |
-
|
2280 |
-
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
2281 |
-
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
2282 |
-
pipeline.to(device)
|
2283 |
-
|
2284 |
-
save_dir = args.output_dir + "/sample"
|
2285 |
-
os.makedirs(save_dir, exist_ok=True)
|
2286 |
-
|
2287 |
-
rng_state = torch.get_rng_state()
|
2288 |
-
cuda_rng_state = torch.cuda.get_rng_state()
|
2289 |
-
|
2290 |
-
with torch.no_grad():
|
2291 |
-
with accelerator.autocast():
|
2292 |
-
for i, prompt in enumerate(prompts):
|
2293 |
-
prompt = prompt.strip()
|
2294 |
-
if len(prompt) == 0 or prompt[0] == '#':
|
2295 |
-
continue
|
2296 |
-
|
2297 |
-
# subset of gen_img_diffusers
|
2298 |
-
prompt_args = prompt.split(' --')
|
2299 |
-
prompt = prompt_args[0]
|
2300 |
-
negative_prompt = None
|
2301 |
-
sample_steps = 30
|
2302 |
-
width = height = 512
|
2303 |
-
scale = 7.5
|
2304 |
-
seed = None
|
2305 |
-
for parg in prompt_args:
|
2306 |
-
try:
|
2307 |
-
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
|
2308 |
-
if m:
|
2309 |
-
width = int(m.group(1))
|
2310 |
-
continue
|
2311 |
-
|
2312 |
-
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
|
2313 |
-
if m:
|
2314 |
-
height = int(m.group(1))
|
2315 |
-
continue
|
2316 |
-
|
2317 |
-
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
|
2318 |
-
if m:
|
2319 |
-
seed = int(m.group(1))
|
2320 |
-
continue
|
2321 |
-
|
2322 |
-
m = re.match(r's (\d+)', parg, re.IGNORECASE)
|
2323 |
-
if m: # steps
|
2324 |
-
sample_steps = max(1, min(1000, int(m.group(1))))
|
2325 |
-
continue
|
2326 |
-
|
2327 |
-
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
|
2328 |
-
if m: # scale
|
2329 |
-
scale = float(m.group(1))
|
2330 |
-
continue
|
2331 |
-
|
2332 |
-
m = re.match(r'n (.+)', parg, re.IGNORECASE)
|
2333 |
-
if m: # negative prompt
|
2334 |
-
negative_prompt = m.group(1)
|
2335 |
-
continue
|
2336 |
-
|
2337 |
-
except ValueError as ex:
|
2338 |
-
print(f"Exception in parsing / 解析エラー: {parg}")
|
2339 |
-
print(ex)
|
2340 |
-
|
2341 |
-
if seed is not None:
|
2342 |
-
torch.manual_seed(seed)
|
2343 |
-
torch.cuda.manual_seed(seed)
|
2344 |
-
|
2345 |
-
if prompt_replacement is not None:
|
2346 |
-
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2347 |
-
if negative_prompt is not None:
|
2348 |
-
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2349 |
-
|
2350 |
-
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
2351 |
-
|
2352 |
-
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
2353 |
-
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
2354 |
-
seed_suffix = "" if seed is None else f"_{seed}"
|
2355 |
-
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
2356 |
-
|
2357 |
-
image.save(os.path.join(save_dir, img_filename))
|
2358 |
-
|
2359 |
-
torch.set_rng_state(rng_state)
|
2360 |
-
torch.cuda.set_rng_state(cuda_rng_state)
|
2361 |
-
vae.to(org_vae_device)
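The prompt file read by `sample_images` is plain text: one prompt per line, lines starting with `#` are skipped, and the ` --` options parsed above may follow the prompt (`w`/`h` size, `d` seed, `s` steps, `l` CFG scale, `n` negative prompt). A hypothetical `sample_prompts.txt`, with made-up prompts, just to illustrate the format:

# lines starting with # are ignored
1girl, silver hair, looking at viewer --w 512 --h 768 --d 1234 --s 28 --l 7.5 --n lowres, bad anatomy
a photograph of a lighthouse at dusk --d 42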
|
2362 |
-
|
2363 |
# endregion
|
2364 |
|
2365 |
# region 前処理用
|
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
|
|
4 |
import json
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
import glob
|
11 |
import math
|
12 |
import os
|
|
|
17 |
|
18 |
from tqdm import tqdm
|
19 |
import torch
|
|
|
20 |
from torchvision import transforms
|
21 |
from transformers import CLIPTokenizer
|
|
|
22 |
import diffusers
|
23 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
|
|
|
|
|
|
|
|
24 |
import albumentations as albu
|
25 |
import numpy as np
|
26 |
from PIL import Image
|
|
|
195 |
batch_index: int
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
198 |
class BaseDataset(torch.utils.data.Dataset):
|
199 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
200 |
super().__init__()
|
201 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
202 |
self.max_token_length = max_token_length
|
203 |
+
self.shuffle_caption = shuffle_caption
|
204 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
# width/height is used when enable_bucket==False
|
206 |
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
+
self.face_crop_aug_range = face_crop_aug_range
|
208 |
+
self.flip_aug = flip_aug
|
209 |
+
self.color_aug = color_aug
|
210 |
self.debug_dataset = debug_dataset
|
211 |
+
self.random_crop = random_crop
|
|
|
|
|
212 |
self.token_padding_disabled = False
|
213 |
+
self.dataset_dirs_info = {}
|
214 |
+
self.reg_dataset_dirs_info = {}
|
215 |
self.tag_frequency = {}
|
216 |
|
217 |
self.enable_bucket = False
|
|
|
225 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
|
227 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
+
self.dropout_rate: float = 0
|
229 |
+
self.dropout_every_n_epochs: int = None
|
230 |
+
self.tag_dropout_rate: float = 0
|
231 |
|
232 |
# augmentation
|
233 |
+
flip_p = 0.5 if flip_aug else 0.0
|
234 |
+
if color_aug:
|
235 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
+
self.aug = albu.Compose([
|
237 |
+
albu.OneOf([
|
238 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
+
albu.RandomGamma((95, 105), p=.5),
|
240 |
+
], p=.33),
|
241 |
+
albu.HorizontalFlip(p=flip_p)
|
242 |
+
], p=1.)
|
243 |
+
elif flip_aug:
|
244 |
+
self.aug = albu.Compose([
|
245 |
+
albu.HorizontalFlip(p=flip_p)
|
246 |
+
], p=1.)
|
247 |
+
else:
|
248 |
+
self.aug = None
|
249 |
|
250 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
|
252 |
self.image_data: Dict[str, ImageInfo] = {}
|
|
|
253 |
|
254 |
self.replacements = {}
|
255 |
|
256 |
def set_current_epoch(self, epoch):
|
257 |
self.current_epoch = epoch
|
258 |
+
|
259 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
+
self.dropout_rate = dropout_rate
|
262 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
+
self.tag_dropout_rate = tag_dropout_rate
|
264 |
|
265 |
def set_tag_frequency(self, dir_name, captions):
|
266 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
for caption in captions:
|
269 |
for tag in caption.split(","):
|
270 |
+
if tag and not tag.isspace():
|
|
|
271 |
tag = tag.lower()
|
272 |
frequency = frequency_for_dir.get(tag, 0)
|
273 |
frequency_for_dir[tag] = frequency + 1
|
|
|
278 |
def add_replacement(self, str_from, str_to):
|
279 |
self.replacements[str_from] = str_to
|
280 |
|
281 |
+
def process_caption(self, caption):
|
282 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
284 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
285 |
|
286 |
if is_drop_out:
|
287 |
caption = ""
|
288 |
else:
|
289 |
+
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
290 |
def dropout_tags(tokens):
|
291 |
+
if self.tag_dropout_rate <= 0:
|
292 |
return tokens
|
293 |
l = []
|
294 |
for token in tokens:
|
295 |
+
if random.random() >= self.tag_dropout_rate:
|
296 |
l.append(token)
|
297 |
return l
|
298 |
|
299 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
300 |
+
if self.shuffle_keep_tokens is None:
|
301 |
+
if self.shuffle_caption:
|
302 |
+
random.shuffle(tokens)
|
303 |
+
|
304 |
+
tokens = dropout_tags(tokens)
|
305 |
+
else:
|
306 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
|
310 |
+
if self.shuffle_caption:
|
311 |
+
random.shuffle(tokens)
|
312 |
|
313 |
+
tokens = dropout_tags(tokens)
|
314 |
|
315 |
+
tokens = keep_tokens + tokens
|
316 |
+
caption = ", ".join(tokens)
|
317 |
|
318 |
# textual inversion対応
|
319 |
for str_from, str_to in self.replacements.items():
|
|
|
367 |
input_ids = torch.stack(iids_list) # 3,77
|
368 |
return input_ids
|
369 |
|
370 |
+
def register_image(self, info: ImageInfo):
|
371 |
self.image_data[info.image_key] = info
|
|
|
372 |
|
373 |
def make_buckets(self):
|
374 |
'''
|
|
|
467 |
img = np.array(image, np.uint8)
|
468 |
return img
|
469 |
|
470 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
image_height, image_width = image.shape[0:2]
|
472 |
|
473 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
477 |
image_height, image_width = image.shape[0:2]
|
478 |
if image_width > reso[0]:
|
479 |
trim_size = image_width - reso[0]
|
480 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
481 |
# print("w", trim_size, p)
|
482 |
image = image[:, p:p + reso[0]]
|
483 |
if image_height > reso[1]:
|
484 |
trim_size = image_height - reso[1]
|
485 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
486 |
# print("h", trim_size, p)
|
487 |
image = image[p:p + reso[1]]
|
488 |
|
489 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
return image
|
491 |
|
|
|
|
|
|
|
492 |
def cache_latents(self, vae):
|
493 |
# TODO ここを高速化したい
|
494 |
print("caching latents.")
|
495 |
for info in tqdm(self.image_data.values()):
|
|
|
|
|
496 |
if info.latents_npz is not None:
|
497 |
info.latents = self.load_latents_from_npz(info, False)
|
498 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
502 |
continue
|
503 |
|
504 |
image = self.load_image(info.absolute_path)
|
505 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
|
507 |
img_tensor = self.image_transforms(image)
|
508 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
|
511 |
+
if self.flip_aug:
|
512 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
img_tensor = self.image_transforms(image)
|
514 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
518 |
image = Image.open(image_path)
|
519 |
return image.size
|
520 |
|
521 |
+
def load_image_with_face_info(self, image_path: str):
|
522 |
img = self.load_image(image_path)
|
523 |
|
524 |
face_cx = face_cy = face_w = face_h = 0
|
525 |
+
if self.face_crop_aug_range is not None:
|
526 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
if len(tokens) >= 5:
|
528 |
face_cx = int(tokens[-4])
|
|
|
533 |
return img, face_cx, face_cy, face_w, face_h
|
534 |
|
535 |
# いい感じに切り出す
|
536 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
height, width = image.shape[0:2]
|
538 |
if height == self.height and width == self.width:
|
539 |
return image
|
|
|
541 |
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
face_size = max(face_w, face_h)
|
543 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
545 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
546 |
if min_scale >= max_scale: # range指定がmin==max
|
547 |
scale = min_scale
|
548 |
else:
|
|
|
560 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
|
563 |
+
if self.random_crop:
|
564 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
else:
|
568 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
570 |
if face_size > self.size // 10 and face_size >= 40:
|
571 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
|
|
|
589 |
return self._length
|
590 |
|
591 |
def __getitem__(self, index):
|
592 |
+
if index == 0:
|
593 |
+
self.shuffle_buckets()
|
594 |
+
|
595 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
604 |
|
605 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
image_info = self.image_data[image_key]
|
|
|
607 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
|
609 |
# image/latentsを処理する
|
610 |
if image_info.latents is not None:
|
611 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
612 |
image = None
|
613 |
elif image_info.latents_npz is not None:
|
614 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
615 |
latents = torch.FloatTensor(latents)
|
616 |
image = None
|
617 |
else:
|
618 |
# 画像を読み込み、必要ならcropする
|
619 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
im_h, im_w = img.shape[0:2]
|
621 |
|
622 |
if self.enable_bucket:
|
623 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
else:
|
625 |
if face_cx > 0: # 顔位置情報あり
|
626 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
elif im_h > self.height or im_w > self.width:
|
628 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
629 |
if im_h > self.height:
|
630 |
p = random.randint(0, im_h - self.height)
|
631 |
img = img[p:p + self.height]
|
|
|
637 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
|
639 |
# augmentation
|
640 |
+
if self.aug is not None:
|
641 |
+
img = self.aug(image=img)['image']
|
|
|
642 |
|
643 |
latents = None
|
644 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
646 |
images.append(image)
|
647 |
latents_list.append(latents)
|
648 |
|
649 |
+
caption = self.process_caption(image_info.caption)
|
650 |
captions.append(caption)
|
651 |
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
677 |
|
678 |
|
679 |
class DreamBoothDataset(BaseDataset):
|
680 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
681 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
682 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
|
684 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
|
|
|
702 |
self.bucket_reso_steps = None # この情報は使われない
|
703 |
self.bucket_no_upscale = False
|
704 |
|
705 |
+
def read_caption(img_path):
|
706 |
# captionの候補ファイル名を作る
|
707 |
base_name = os.path.splitext(img_path)[0]
|
708 |
base_name_face_det = base_name
|
|
|
725 |
break
|
726 |
return caption
|
727 |
|
728 |
+
def load_dreambooth_dir(dir):
|
729 |
+
if not os.path.isdir(dir):
|
730 |
+
# print(f"ignore file: {dir}")
|
731 |
+
return 0, [], []
|
732 |
+
|
733 |
+
tokens = os.path.basename(dir).split('_')
|
734 |
+
try:
|
735 |
+
n_repeats = int(tokens[0])
|
736 |
+
except ValueError as e:
|
737 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
+
return 0, [], []
|
739 |
|
740 |
+
caption_by_folder = '_'.join(tokens[1:])
|
741 |
+
img_paths = glob_images(dir, "*")
|
742 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
|
744 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
captions = []
|
746 |
for img_path in img_paths:
|
747 |
+
cap_for_img = read_caption(img_path)
|
748 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
749 |
|
750 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
|
|
|
|
|
|
|
|
|
751 |
|
752 |
+
return n_repeats, img_paths, captions
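As `load_dreambooth_dir` above shows, every subdirectory of `train_data_dir` (and of `reg_data_dir`) is expected to be named `<repeats>_<caption>`; a directory whose name does not start with an integer is skipped with a message, and the caption taken from the folder name is only used for images that have no caption file of their own. A hypothetical layout as an illustration (names are made up):

train_data_dir/
  10_shs 1girl/            # each image counted 10 times; folder caption "shs 1girl"
    img_0001.png
    img_0001.caption       # per-image caption overrides the folder caption
    img_0002.png
  notes/                   # no leading repeat count -> ignored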
|
|
|
|
|
|
|
753 |
|
754 |
+
print("prepare train images.")
|
755 |
+
train_dirs = os.listdir(train_data_dir)
|
756 |
+
num_train_images = 0
|
757 |
+
for dir in train_dirs:
|
758 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
759 |
+
num_train_images += n_repeats * len(img_paths)
|
760 |
|
761 |
for img_path, caption in zip(img_paths, captions):
|
762 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
763 |
+
self.register_image(info)
|
|
|
|
|
|
|
764 |
|
765 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
|
|
766 |
|
767 |
print(f"{num_train_images} train images with repeating.")
|
768 |
self.num_train_images = num_train_images
|
769 |
|
770 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
771 |
+
num_reg_images = 0
|
772 |
+
if reg_data_dir:
|
773 |
+
print("prepare reg images.")
|
774 |
+
reg_infos: List[ImageInfo] = []
|
|
|
|
|
|
|
|
|
|
|
|
775 |
|
776 |
+
reg_dirs = os.listdir(reg_data_dir)
|
777 |
+
for dir in reg_dirs:
|
778 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
779 |
+
num_reg_images += n_repeats * len(img_paths)
|
780 |
|
781 |
+
for img_path, caption in zip(img_paths, captions):
|
782 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
+
reg_infos.append(info)
|
784 |
|
785 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
|
|
|
|
|
|
786 |
|
787 |
+
print(f"{num_reg_images} reg images.")
|
788 |
+
if num_train_images < num_reg_images:
|
789 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
790 |
|
791 |
+
if num_reg_images == 0:
|
792 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
|
|
|
|
|
|
793 |
else:
|
794 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
795 |
+
n = 0
|
796 |
+
first_loop = True
|
797 |
+
while n < num_train_images:
|
798 |
+
for info in reg_infos:
|
799 |
+
if first_loop:
|
800 |
+
self.register_image(info)
|
801 |
+
n += info.num_repeats
|
802 |
+
else:
|
803 |
+
info.num_repeats += 1
|
804 |
+
n += 1
|
805 |
+
if n >= num_train_images:
|
806 |
+
break
|
807 |
+
first_loop = False
|
808 |
|
809 |
+
self.num_reg_images = num_reg_images
|
|
|
|
|
|
|
|
|
|
|
|
810 |
|
|
|
|
|
811 |
|
812 |
+
class FineTuningDataset(BaseDataset):
|
813 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
814 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
815 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
816 |
+
|
817 |
+
# メタデータを読み込む
|
818 |
+
if os.path.exists(json_file_name):
|
819 |
+
print(f"loading existing metadata: {json_file_name}")
|
820 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
+
metadata = json.load(f)
|
822 |
+
else:
|
823 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
|
825 |
+
self.metadata = metadata
|
826 |
+
self.train_data_dir = train_data_dir
|
827 |
+
self.batch_size = batch_size
|
828 |
|
829 |
+
tags_list = []
|
830 |
+
for image_key, img_md in metadata.items():
|
831 |
+
# path情報を作る
|
832 |
+
if os.path.exists(image_key):
|
833 |
+
abs_path = image_key
|
834 |
+
else:
|
835 |
+
# わりといい加減だがいい方法が思いつかん
|
836 |
+
abs_path = glob_images(train_data_dir, image_key)
|
837 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
+
abs_path = abs_path[0]
|
839 |
+
|
840 |
+
caption = img_md.get('caption')
|
841 |
+
tags = img_md.get('tags')
|
842 |
+
if caption is None:
|
843 |
+
caption = tags
|
844 |
+
elif tags is not None and len(tags) > 0:
|
845 |
+
caption = caption + ', ' + tags
|
846 |
+
tags_list.append(tags)
|
847 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
+
|
849 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
+
image_info.image_size = img_md.get('train_resolution')
|
851 |
+
|
852 |
+
if not self.color_aug and not self.random_crop:
|
853 |
+
# if npz exists, use them
|
854 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
+
|
856 |
+
self.register_image(image_info)
|
857 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
+
self.num_reg_images = 0
|
859 |
|
860 |
+
# TODO do not record tag freq when no tag
|
861 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
862 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
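The metadata file consumed by this constructor is a JSON object keyed by image path (or by an image key resolved relative to `train_data_dir`), and each entry may provide `caption`, `tags` and `train_resolution`, exactly the fields read above. A minimal hypothetical example (paths and values are invented):

{
  "images/0001.png": {"caption": "a corgi on a beach", "tags": "dog, outdoors, beach", "train_resolution": [512, 512]},
  "images/0002.png": {"tags": "1girl, smile"}
}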
|
|
|
863 |
|
864 |
# check existence of all npz files
|
865 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
866 |
if use_npz_latents:
|
|
|
867 |
npz_any = False
|
868 |
npz_all = True
|
|
|
869 |
for image_info in self.image_data.values():
|
|
|
|
|
870 |
has_npz = image_info.latents_npz is not None
|
871 |
npz_any = npz_any or has_npz
|
872 |
|
873 |
+
if self.flip_aug:
|
874 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
|
|
875 |
npz_all = npz_all and has_npz
|
876 |
|
877 |
if npz_any and not npz_all:
|
|
|
883 |
elif not npz_all:
|
884 |
use_npz_latents = False
|
885 |
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
+
if self.flip_aug:
|
887 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
# else:
|
889 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
|
|
929 |
for image_info in self.image_data.values():
|
930 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
|
932 |
+
def image_key_to_npz_file(self, image_key):
|
933 |
base_name = os.path.splitext(image_key)[0]
|
934 |
npz_file_norm = base_name + '.npz'
|
935 |
|
|
|
941 |
return npz_file_norm, npz_file_flip
|
942 |
|
943 |
# image_key is relative path
|
944 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
945 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
946 |
|
947 |
if not os.path.exists(npz_file_norm):
|
948 |
npz_file_norm = None
|
|
|
953 |
return npz_file_norm, npz_file_flip
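So for each image key the cached latents are expected as a `.npz` file next to the image (or under `train_data_dir` when the key is relative), with a `_flip.npz` sibling that is only needed when `flip_aug` is enabled. A short sketch of the naming convention with a hypothetical path:

import os

base = os.path.splitext("dataset/img_001.png")[0]        # hypothetical image key
npz_norm, npz_flip = base + ".npz", base + "_flip.npz"   # -> dataset/img_001.npz, dataset/img_001_flip.npz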
|
954 |
|
955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
|
960 |
train_dataset.set_current_epoch(1)
|
961 |
k = 0
|
962 |
+
for i, example in enumerate(train_dataset):
|
|
|
|
|
|
|
963 |
if example['latents'] is not None:
|
964 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
|
|
1364 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
|
|
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
|
1369 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
|
|
1387 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
parser.add_argument("--xformers", action="store_true",
|
|
|
1398 |
parser.add_argument("--vae", type=str, default=None,
|
1399 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
|
1401 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
|
|
1419 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
parser.add_argument("--lowram", action="store_true",
|
1429 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
|
|
|
|
|
|
|
|
|
|
|
1431 |
if support_dreambooth:
|
1432 |
# DreamBooth training
|
1433 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
|
|
1449 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1453 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1454 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
|
|
1475 |
if support_caption_dropout:
|
1476 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1481 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
|
1485 |
if support_dreambooth:
|
|
|
1504 |
# region utils
|
1505 |
|
1506 |
|
|
|
|
|
|
|
|
|
|
|
1507 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
# backward compatibility
|
1509 |
if args.caption_extention is not None:
|
1510 |
args.caption_extension = args.caption_extention
|
1511 |
args.caption_extention = None
|
1512 |
|
1513 |
+
if args.cache_latents:
|
1514 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
+
|
1517 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
if args.resolution is not None:
|
1519 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
|
|
1536 |
|
1537 |
def load_tokenizer(args: argparse.Namespace):
|
1538 |
print("prepare tokenizer")
|
1539 |
+
if args.v2:
|
1540 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1541 |
+
else:
|
1542 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1543 |
+
if args.max_token_length is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1544 |
print(f"update token length: {args.max_token_length}")
|
|
|
|
|
|
|
|
|
|
|
1545 |
return tokenizer
|
1546 |
|
1547 |
|
|
|
1592 |
|
1593 |
|
1594 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
|
|
|
|
1596 |
if load_stable_diffusion_format:
|
1597 |
print("load StableDiffusion checkpoint")
|
1598 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1599 |
else:
|
1600 |
print("load Diffusers pretrained models")
|
1601 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
|
|
|
|
|
|
|
|
1602 |
text_encoder = pipe.text_encoder
|
1603 |
vae = pipe.vae
|
1604 |
unet = pipe.unet
|
|
|
1767 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
|
|
|
|
|
|
|
|
|
|
|
|
1770 |
# endregion
|
1771 |
|
1772 |
# region 前処理用
|
networks/lora.py
CHANGED
@@ -126,11 +126,6 @@ class LoRANetwork(torch.nn.Module):
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
129 |
-
def set_multiplier(self, multiplier):
|
130 |
-
self.multiplier = multiplier
|
131 |
-
for lora in self.text_encoder_loras + self.unet_loras:
|
132 |
-
lora.multiplier = self.multiplier
|
133 |
-
|
134 |
def load_weights(self, file):
|
135 |
if os.path.splitext(file)[1] == '.safetensors':
|
136 |
from safetensors.torch import load_file, safe_open
|
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
129 |
def load_weights(self, file):
|
130 |
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
from safetensors.torch import load_file, safe_open
|
tools/convert_diffusers20_original_sd.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
1 |
+
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from diffusers import StableDiffusionPipeline
|
7 |
+
|
8 |
+
import library.model_util as model_util
|
9 |
+
|
10 |
+
|
11 |
+
def convert(args):
|
12 |
+
# 引数を確認する
|
13 |
+
load_dtype = torch.float16 if args.fp16 else None
|
14 |
+
|
15 |
+
save_dtype = None
|
16 |
+
if args.fp16:
|
17 |
+
save_dtype = torch.float16
|
18 |
+
elif args.bf16:
|
19 |
+
save_dtype = torch.bfloat16
|
20 |
+
elif args.float:
|
21 |
+
save_dtype = torch.float
|
22 |
+
|
23 |
+
is_load_ckpt = os.path.isfile(args.model_to_load)
|
24 |
+
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
25 |
+
|
26 |
+
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
27 |
+
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
28 |
+
|
29 |
+
# モデルを読み込む
|
30 |
+
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
31 |
+
print(f"loading {msg}: {args.model_to_load}")
|
32 |
+
|
33 |
+
if is_load_ckpt:
|
34 |
+
v2_model = args.v2
|
35 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
36 |
+
else:
|
37 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
|
38 |
+
text_encoder = pipe.text_encoder
|
39 |
+
vae = pipe.vae
|
40 |
+
unet = pipe.unet
|
41 |
+
|
42 |
+
if args.v1 == args.v2:
|
43 |
+
# 自動判定する
|
44 |
+
v2_model = unet.config.cross_attention_dim == 1024
|
45 |
+
print("checking model version: model is " + ('v2' if v2_model else 'v1'))
|
46 |
+
else:
|
47 |
+
v2_model = not args.v1
|
48 |
+
|
49 |
+
# 変換して保存する
|
50 |
+
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
51 |
+
print(f"converting and saving as {msg}: {args.model_to_save}")
|
52 |
+
|
53 |
+
if is_save_ckpt:
|
54 |
+
original_model = args.model_to_load if is_load_ckpt else None
|
55 |
+
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
|
56 |
+
original_model, args.epoch, args.global_step, save_dtype, vae)
|
57 |
+
print(f"model saved. total converted state_dict keys: {key_count}")
|
58 |
+
else:
|
59 |
+
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
60 |
+
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
|
61 |
+
print(f"model saved.")
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
parser = argparse.ArgumentParser()
|
66 |
+
parser.add_argument("--v1", action='store_true',
|
67 |
+
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
68 |
+
parser.add_argument("--v2", action='store_true',
|
69 |
+
help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
|
70 |
+
parser.add_argument("--fp16", action='store_true',
|
71 |
+
help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
|
72 |
+
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
73 |
+
parser.add_argument("--float", action='store_true',
|
74 |
+
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
75 |
+
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
76 |
+
parser.add_argument("--global_step", type=int, default=0,
|
77 |
+
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
78 |
+
parser.add_argument("--reference_model", type=str, default=None,
|
79 |
+
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
80 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
81 |
+
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
|
82 |
+
|
83 |
+
parser.add_argument("model_to_load", type=str, default=None,
|
84 |
+
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
85 |
+
parser.add_argument("model_to_save", type=str, default=None,
|
86 |
+
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
87 |
+
|
88 |
+
args = parser.parse_args()
|
89 |
+
convert(args)
|
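A minimal usage sketch for the converter above, assuming the repository root is on sys.path so that the tools package and library.model_util are importable; the model paths are placeholders, and the fields mirror the argparse options the script defines:

# Hypothetical programmatic equivalent of:
#   python tools/convert_diffusers20_original_sd.py --v1 --fp16 path/to/diffusers_dir path/to/model.safetensors
from argparse import Namespace
from tools.convert_diffusers20_original_sd import convert  # assumes repo root on sys.path

args = Namespace(
    v1=True, v2=False,                    # model version flags (required when loading a checkpoint)
    fp16=True, bf16=False, float=False,   # load/save precision
    epoch=0, global_step=0,               # values written into the checkpoint header
    reference_model=None,                 # only needed when saving in Diffusers format
    use_safetensors=False,
    model_to_load="path/to/diffusers_dir",           # placeholder: Diffusers model directory
    model_to_save="path/to/model.safetensors",       # placeholder: extension present => saved as checkpoint
)
convert(args)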
tools/detect_face_rotate.py
ADDED
@@ -0,0 +1,239 @@
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss

# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す

# v2: extract max face if multiple faces are found
# v3: add crop_ratio option
# v4: add multiple faces extraction and min/max size

import argparse
import math
import cv2
import glob
import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np

KP_REYE = 11
KP_LEYE = 19

SCORE_THRES = 0.90


def detect_faces(detector, image, min_size):
  preds = detector(image)                     # bgr
  # print(len(preds))

  faces = []
  for pred in preds:
    bb = pred['bbox']
    score = bb[-1]
    if score < SCORE_THRES:
      continue

    left, top, right, bottom = bb[:4]
    cx = int((left + right) / 2)
    cy = int((top + bottom) / 2)
    fw = int(right - left)
    fh = int(bottom - top)

    lex, ley = pred['keypoints'][KP_LEYE, 0:2]
    rex, rey = pred['keypoints'][KP_REYE, 0:2]
    angle = math.atan2(ley - rey, lex - rex)
    angle = angle / math.pi * 180

    faces.append((cx, cy, fw, fh, angle))

  faces.sort(key=lambda x: max(x[2], x[3]), reverse=True)         # 大きい順
  return faces


def rotate_image(image, angle, cx, cy):
  h, w = image.shape[0:2]
  rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

  # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
  # nh = max(h, int(w * math.sin(angle)))
  # nw = max(w, int(h * math.sin(angle)))
  # if nh > h or nw > w:
  #   pad_y = nh - h
  #   pad_t = pad_y // 2
  #   pad_x = nw - w
  #   pad_l = pad_x // 2
  #   m = np.array([[0, 0, pad_l],
  #                 [0, 0, pad_t]])
  #   rot_mat = rot_mat + m
  #   h, w = nh, nw
  #   cx += pad_l
  #   cy += pad_t

  result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
  return result, cx, cy


def process(args):
  assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
  assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"

  # アニメ顔検出モデルを読み込む
  print("loading face detector.")
  detector = create_detector('yolov3')

  # cropの引数を解析する
  if args.crop_size is None:
    crop_width = crop_height = None
  else:
    tokens = args.crop_size.split(',')
    assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
    crop_width, crop_height = [int(t) for t in tokens]

  if args.crop_ratio is None:
    crop_h_ratio = crop_v_ratio = None
  else:
    tokens = args.crop_ratio.split(',')
    assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
    crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]

  # 画像を処理する
  print("processing.")
  output_extension = ".png"

  os.makedirs(args.dst_dir, exist_ok=True)
  paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.src_dir, "*.webp"))
  for path in tqdm(paths):
    basename = os.path.splitext(os.path.basename(path))[0]

    # image = cv2.imread(path)       # 日本語ファイル名でエラーになる
    image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
    if len(image.shape) == 2:
      image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
      print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
      image = image[:, :, :3].copy()             # copyをしないと内部的に透明度情報が付いたままになるらしい

    h, w = image.shape[:2]

    faces = detect_faces(detector, image, args.multiple_faces)
    for i, face in enumerate(faces):
      cx, cy, fw, fh, angle = face
      face_size = max(fw, fh)
      if args.min_size is not None and face_size < args.min_size:
        continue
      if args.max_size is not None and face_size >= args.max_size:
        continue
      face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""

      # オプション指定があれば回転する
      face_img = image
      if args.rotate:
        face_img, cx, cy = rotate_image(face_img, angle, cx, cy)

      # オプション指定があれば顔を中心に切り出す
      if crop_width is not None or crop_h_ratio is not None:
        cur_crop_width, cur_crop_height = crop_width, crop_height
        if crop_h_ratio is not None:
          cur_crop_width = int(face_size * crop_h_ratio + .5)
          cur_crop_height = int(face_size * crop_v_ratio + .5)

        # リサイズを必要なら行う
        scale = 1.0
        if args.resize_face_size is not None:
          # 顔サイズを基準にリサイズする
          scale = args.resize_face_size / face_size
          if scale < cur_crop_width / w:
            print(
                f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
            scale = cur_crop_width / w
          if scale < cur_crop_height / h:
            print(
                f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
            scale = cur_crop_height / h
        elif crop_h_ratio is not None:
          # 倍率指定の時にはリサイズしない
          pass
        else:
          # 切り出しサイズ指定あり
          if w < cur_crop_width:
            print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
            scale = cur_crop_width / w
          if h < cur_crop_height:
            print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
            scale = cur_crop_height / h
          if args.resize_fit:
            scale = max(cur_crop_width / w, cur_crop_height / h)

        if scale != 1.0:
          w = int(w * scale + .5)
          h = int(h * scale + .5)
          face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
          cx = int(cx * scale + .5)
          cy = int(cy * scale + .5)
          fw = int(fw * scale + .5)
          fh = int(fh * scale + .5)

        cur_crop_width = min(cur_crop_width, face_img.shape[1])
        cur_crop_height = min(cur_crop_height, face_img.shape[0])

        x = cx - cur_crop_width // 2
        cx = cur_crop_width // 2
        if x < 0:
          cx = cx + x
          x = 0
        elif x + cur_crop_width > w:
          cx = cx + (x + cur_crop_width - w)
          x = w - cur_crop_width
        face_img = face_img[:, x:x+cur_crop_width]

        y = cy - cur_crop_height // 2
        cy = cur_crop_height // 2
        if y < 0:
          cy = cy + y
          y = 0
        elif y + cur_crop_height > h:
          cy = cy + (y + cur_crop_height - h)
          y = h - cur_crop_height
        face_img = face_img[y:y + cur_crop_height]

      # # debug
      # print(path, cx, cy, angle)
      # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
      # cv2.imshow("image", crp)
      # if cv2.waitKey() == 27:
      #   break
      # cv2.destroyAllWindows()

      # debug
      if args.debug:
        cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)

      _, buf = cv2.imencode(output_extension, face_img)
      with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
        buf.tofile(f)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
  parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
  parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
  parser.add_argument("--resize_fit", action="store_true",
                      help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
  parser.add_argument("--resize_face_size", type=int, default=None,
                      help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
  parser.add_argument("--crop_size", type=str, default=None,
                      help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
  parser.add_argument("--crop_ratio", type=str, default=None,
                      help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
  parser.add_argument("--min_size", type=int, default=None,
                      help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
  parser.add_argument("--max_size", type=int, default=None,
                      help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
  parser.add_argument("--multiple_faces", action="store_true",
                      help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
  parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
  args = parser.parse_args()

  process(args)
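A small usage sketch for the face-cropping tool, assuming anime_face_detector is installed and the repository root is on sys.path; the folder names are placeholders and every field mirrors an argparse option of the script:

# Hypothetical programmatic equivalent of:
#   python tools/detect_face_rotate.py --src_dir raw --dst_dir cropped --rotate --crop_size 512,512
from argparse import Namespace
from tools.detect_face_rotate import process  # assumes repo root on sys.path

args = Namespace(
    src_dir="raw",             # placeholder: folder with png/jpg/webp images
    dst_dir="cropped",         # placeholder: output folder (created if missing)
    rotate=True,               # rotate so detected faces are upright
    resize_fit=False, resize_face_size=None,
    crop_size="512,512",       # crop a 512x512 window centered on each face
    crop_ratio=None,
    min_size=None, max_size=None,
    multiple_faces=False,      # when True, each detected face gets its own numbered crop
    debug=False,
)
process(args)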
tools/resize_images_to_resolution.py
ADDED
@@ -0,0 +1,122 @@
import glob
import os
import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np


def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
  # Split the max_resolution string by "," and strip any whitespaces
  max_resolutions = [res.strip() for res in max_resolution.split(',')]

  # # Calculate max_pixels from max_resolution string
  # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

  # Create destination folder if it does not exist
  if not os.path.exists(dst_img_folder):
    os.makedirs(dst_img_folder)

  # Select interpolation method
  if interpolation == 'lanczos4':
    cv2_interpolation = cv2.INTER_LANCZOS4
  elif interpolation == 'cubic':
    cv2_interpolation = cv2.INTER_CUBIC
  else:
    cv2_interpolation = cv2.INTER_AREA

  # Iterate through all files in src_img_folder
  img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")  # copy from train_util.py
  for filename in os.listdir(src_img_folder):
    # Check if the image is png, jpg or webp etc...
    if not filename.endswith(img_exts):
      # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
      shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
      continue

    # Load image
    # img = cv2.imread(os.path.join(src_img_folder, filename))
    image = Image.open(os.path.join(src_img_folder, filename))
    if not image.mode == "RGB":
      image = image.convert("RGB")
    img = np.array(image, np.uint8)

    base, _ = os.path.splitext(filename)
    for max_resolution in max_resolutions:
      # Calculate max_pixels from max_resolution string
      max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

      # Calculate current number of pixels
      current_pixels = img.shape[0] * img.shape[1]

      # Check if the image needs resizing
      if current_pixels > max_pixels:
        # Calculate scaling factor
        scale_factor = max_pixels / current_pixels

        # Calculate new dimensions
        new_height = int(img.shape[0] * math.sqrt(scale_factor))
        new_width = int(img.shape[1] * math.sqrt(scale_factor))

        # Resize image
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
      else:
        new_height, new_width = img.shape[0:2]

      # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
      new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
      new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by

      # Center crop the image to the calculated dimensions
      y = int((img.shape[0] - new_height) / 2)
      x = int((img.shape[1] - new_width) / 2)
      img = img[y:y + new_height, x:x + new_width]

      # Split filename into base and extension
      new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')

      # Save resized image in dst_img_folder
      # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
      image = Image.fromarray(img)
      image.save(os.path.join(dst_img_folder, new_filename), quality=100)

      proc = "Resized" if current_pixels > max_pixels else "Saved"
      print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")

    # If other files with same basename, copy them with resolution suffix
    if copy_associated_files:
      asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
      for asoc_file in asoc_files:
        ext = os.path.splitext(asoc_file)[1]
        if ext in img_exts:
          continue
        for max_resolution in max_resolutions:
          new_asoc_file = base + '+' + max_resolution + ext
          print(f"Copy {asoc_file} as {new_asoc_file}")
          shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))


def main():
  parser = argparse.ArgumentParser(
      description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
  parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
  parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
  parser.add_argument('--max_resolution', type=str,
                      help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
  parser.add_argument('--divisible_by', type=int,
                      help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
  parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
                      default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
  parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
  parser.add_argument('--copy_associated_files', action='store_true',
                      help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')

  args = parser.parse_args()
  resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
                args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)


if __name__ == '__main__':
  main()
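A short usage sketch for the resizing helper above, calling resize_images() directly; the folder names are placeholders and the repository root is assumed to be on sys.path:

# Downscale every image in train_raw/ to at most 512x512 and 256x256 pixels (by area),
# keeping aspect ratio and copying caption files alongside the resized images.
from tools.resize_images_to_resolution import resize_images  # assumes repo root on sys.path

resize_images(
    "train_raw",                        # placeholder source folder
    "train_resized",                    # placeholder destination folder
    max_resolution="512x512,256x256",
    divisible_by=64,                    # make both sides divisible by 64, e.g. for bucketing
    interpolation="area",
    save_as_png=False,
    copy_associated_files=True,
)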
train_db.py
CHANGED
@@ -15,11 +15,7 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
-
-from library.config_util import (
-  ConfigSanitizer,
-  BlueprintGenerator,
-)
+from library.train_util import DreamBoothDataset
 
 
 def collate_fn(examples):
@@ -37,33 +33,24 @@ def train(args):
 
   tokenizer = train_util.load_tokenizer(args)
 
-
-
-
-
-
-  if any(getattr(args, attr) is not None for attr in ignored):
-    print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
-  else:
-    user_config = {
-      "datasets": [{
-        "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
-      }]
-    }
-
-  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+  train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
+                                    tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
+                                    args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
+                                    args.bucket_reso_steps, args.bucket_no_upscale,
+                                    args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
 
   if args.no_token_padding:
-
+    train_dataset.disable_token_padding()
+
+  # 学習データのdropout率を設定する
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
+
+  train_dataset.make_buckets()
 
   if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset)
     return
 
-  if cache_latents:
-    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
-
   # acceleratorを準備する
   print("prepare accelerator")
 
@@ -104,7 +91,7 @@ def train(args):
   vae.requires_grad_(False)
   vae.eval()
   with torch.no_grad():
-
+    train_dataset.cache_latents(vae)
   vae.to("cpu")
   if torch.cuda.is_available():
     torch.cuda.empty_cache()
@@ -128,18 +115,38 @@ def train(args):
 
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
+
+  # 8-bit Adamを使う
+  if args.use_8bit_adam:
+    try:
+      import bitsandbytes as bnb
+    except ImportError:
+      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
+    print("use 8-bit Adam optimizer")
+    optimizer_class = bnb.optim.AdamW8bit
+  elif args.use_lion_optimizer:
+    try:
+      import lion_pytorch
+    except ImportError:
+      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+    print("use Lion optimizer")
+    optimizer_class = lion_pytorch.Lion
+  else:
+    optimizer_class = torch.optim.AdamW
+
   if train_text_encoder:
     trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
   else:
    trainable_params = unet.parameters()
 
-
+  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
+  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -149,10 +156,9 @@ def train(args):
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
-  # lr schedulerを用意する
-  lr_scheduler =
-
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  # lr schedulerを用意する
+  lr_scheduler = diffusers.optimization.get_scheduler(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -189,8 +195,8 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
-  print(f" num reg images / 正則化画像の数: {
+  print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
+  print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
   print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f" num epochs / epoch数: {num_train_epochs}")
   print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -211,7 +217,7 @@ def train(args):
   loss_total = 0.0
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-
+    train_dataset.set_current_epoch(epoch + 1)
 
     # 指定したステップ数までText Encoderを学習する:epoch最初の状態
     unet.train()
@@ -275,12 +281,12 @@ def train(args):
         loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients
+        if accelerator.sync_gradients:
          if train_text_encoder:
            params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
          else:
            params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
 
         optimizer.step()
         lr_scheduler.step()
@@ -291,13 +297,9 @@ def train(args):
         progress_bar.update(1)
         global_step += 1
 
-        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr":
-          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
-            logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
+          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
           accelerator.log(logs, step=global_step)
 
         if epoch == 0:
@@ -324,8 +326,6 @@ def train(args):
       train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                             save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
 
-      train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
   is_main_process = accelerator.is_main_process
   if is_main_process:
     unet = unwrap_model(unet)
@@ -352,8 +352,6 @@ if __name__ == '__main__':
   train_util.add_dataset_arguments(parser, True, False, True)
   train_util.add_training_arguments(parser, True)
   train_util.add_sd_saving_arguments(parser)
-  train_util.add_optimizer_arguments(parser)
-  config_util.add_config_arguments(parser)
 
   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_network.py
CHANGED
@@ -1,4 +1,8 @@
|
|
|
|
|
|
|
|
1 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
2 |
import importlib
|
3 |
import argparse
|
4 |
import gc
|
@@ -11,41 +15,94 @@ import json
|
|
11 |
from tqdm import tqdm
|
12 |
import torch
|
13 |
from accelerate.utils import set_seed
|
|
|
14 |
from diffusers import DDPMScheduler
|
15 |
|
16 |
import library.train_util as train_util
|
17 |
-
from library.train_util import
|
18 |
-
DreamBoothDataset,
|
19 |
-
)
|
20 |
-
import library.config_util as config_util
|
21 |
-
from library.config_util import (
|
22 |
-
ConfigSanitizer,
|
23 |
-
BlueprintGenerator,
|
24 |
-
)
|
25 |
|
26 |
|
27 |
def collate_fn(examples):
|
28 |
return examples[0]
|
29 |
|
30 |
|
31 |
-
# TODO 他のスクリプトと共通化する
|
32 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
33 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
34 |
|
35 |
if args.network_train_unet_only:
|
36 |
-
logs["lr/unet"] =
|
37 |
elif args.network_train_text_encoder_only:
|
38 |
-
logs["lr/textencoder"] =
|
39 |
else:
|
40 |
-
logs["lr/textencoder"] =
|
41 |
-
logs["lr/unet"] =
|
42 |
-
|
43 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
44 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
45 |
|
46 |
return logs
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def train(args):
|
50 |
session_id = random.randint(0, 2**32)
|
51 |
training_started_at = time.time()
|
@@ -54,7 +111,6 @@ def train(args):
|
|
54 |
|
55 |
cache_latents = args.cache_latents
|
56 |
use_dreambooth_method = args.in_json is None
|
57 |
-
use_user_config = args.dataset_config is not None
|
58 |
|
59 |
if args.seed is not None:
|
60 |
set_seed(args.seed)
|
@@ -62,47 +118,35 @@ def train(args):
|
|
62 |
tokenizer = train_util.load_tokenizer(args)
|
63 |
|
64 |
# データセットを準備する
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
else:
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
"image_dir": args.train_data_dir,
|
87 |
-
"metadata_file": args.in_json,
|
88 |
-
}]
|
89 |
-
}]
|
90 |
-
}
|
91 |
-
|
92 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
93 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
94 |
|
95 |
if args.debug_dataset:
|
96 |
-
train_util.debug_dataset(
|
97 |
return
|
98 |
-
if len(
|
99 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
100 |
return
|
101 |
|
102 |
-
if cache_latents:
|
103 |
-
assert train_dataset_group.is_latent_cacheable(
|
104 |
-
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
105 |
-
|
106 |
# acceleratorを準備する
|
107 |
print("prepare accelerator")
|
108 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -117,7 +161,7 @@ def train(args):
|
|
117 |
if args.lowram:
|
118 |
text_encoder.to("cuda")
|
119 |
unet.to("cuda")
|
120 |
-
|
121 |
# モデルに xformers とか memory efficient attention を組み込む
|
122 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
123 |
|
@@ -127,7 +171,7 @@ def train(args):
|
|
127 |
vae.requires_grad_(False)
|
128 |
vae.eval()
|
129 |
with torch.no_grad():
|
130 |
-
|
131 |
vae.to("cpu")
|
132 |
if torch.cuda.is_available():
|
133 |
torch.cuda.empty_cache()
|
@@ -164,14 +208,36 @@ def train(args):
|
|
164 |
# 学習に必要なクラスを準備する
|
165 |
print("prepare optimizer, data loader etc.")
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
168 |
-
|
|
|
|
|
169 |
|
170 |
# dataloaderを準備する
|
171 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
172 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
173 |
train_dataloader = torch.utils.data.DataLoader(
|
174 |
-
|
175 |
|
176 |
# 学習ステップ数を計算する
|
177 |
if args.max_train_epochs is not None:
|
@@ -179,9 +245,11 @@ def train(args):
|
|
179 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
180 |
|
181 |
# lr schedulerを用意する
|
182 |
-
lr_scheduler =
|
183 |
-
|
184 |
-
|
|
|
|
|
185 |
|
186 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
187 |
if args.full_fp16:
|
@@ -249,19 +317,17 @@ def train(args):
|
|
249 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
250 |
|
251 |
# 学習する
|
252 |
-
# TODO: find a way to handle total batch size when there are multiple datasets
|
253 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
254 |
print("running training / 学習開始")
|
255 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
256 |
-
print(f" num reg images / 正則化画像の数: {
|
257 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
258 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
259 |
-
print(f" batch size per device / バッチサイズ: {
|
260 |
-
|
261 |
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
262 |
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
263 |
|
264 |
-
# TODO refactor metadata creation and move to util
|
265 |
metadata = {
|
266 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
267 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -269,10 +335,12 @@ def train(args):
|
|
269 |
"ss_learning_rate": args.learning_rate,
|
270 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
271 |
"ss_unet_lr": args.unet_lr,
|
272 |
-
"ss_num_train_images":
|
273 |
-
"ss_num_reg_images":
|
274 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
275 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
276 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
277 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
278 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -284,149 +352,29 @@ def train(args):
|
|
284 |
"ss_mixed_precision": args.mixed_precision,
|
285 |
"ss_full_fp16": bool(args.full_fp16),
|
286 |
"ss_v2": bool(args.v2),
|
|
|
287 |
"ss_clip_skip": args.clip_skip,
|
288 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
289 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
290 |
"ss_seed": args.seed,
|
291 |
-
"
|
292 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
293 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
294 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
295 |
-
"ss_optimizer": optimizer_name
|
296 |
-
"ss_max_grad_norm": args.max_grad_norm,
|
297 |
-
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
298 |
-
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
299 |
-
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
300 |
-
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
301 |
-
"ss_prior_loss_weight": args.prior_loss_weight,
|
302 |
}
|
303 |
|
304 |
-
if use_user_config:
|
305 |
-
# save metadata of multiple datasets
|
306 |
-
# NOTE: pack "ss_datasets" value as json one time
|
307 |
-
# or should also pack nested collections as json?
|
308 |
-
datasets_metadata = []
|
309 |
-
tag_frequency = {} # merge tag frequency for metadata editor
|
310 |
-
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
311 |
-
|
312 |
-
for dataset in train_dataset_group.datasets:
|
313 |
-
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
314 |
-
dataset_metadata = {
|
315 |
-
"is_dreambooth": is_dreambooth_dataset,
|
316 |
-
"batch_size_per_device": dataset.batch_size,
|
317 |
-
"num_train_images": dataset.num_train_images, # includes repeating
|
318 |
-
"num_reg_images": dataset.num_reg_images,
|
319 |
-
"resolution": (dataset.width, dataset.height),
|
320 |
-
"enable_bucket": bool(dataset.enable_bucket),
|
321 |
-
"min_bucket_reso": dataset.min_bucket_reso,
|
322 |
-
"max_bucket_reso": dataset.max_bucket_reso,
|
323 |
-
"tag_frequency": dataset.tag_frequency,
|
324 |
-
"bucket_info": dataset.bucket_info,
|
325 |
-
}
|
326 |
-
|
327 |
-
subsets_metadata = []
|
328 |
-
for subset in dataset.subsets:
|
329 |
-
subset_metadata = {
|
330 |
-
"img_count": subset.img_count,
|
331 |
-
"num_repeats": subset.num_repeats,
|
332 |
-
"color_aug": bool(subset.color_aug),
|
333 |
-
"flip_aug": bool(subset.flip_aug),
|
334 |
-
"random_crop": bool(subset.random_crop),
|
335 |
-
"shuffle_caption": bool(subset.shuffle_caption),
|
336 |
-
"keep_tokens": subset.keep_tokens,
|
337 |
-
}
|
338 |
-
|
339 |
-
image_dir_or_metadata_file = None
|
340 |
-
if subset.image_dir:
|
341 |
-
image_dir = os.path.basename(subset.image_dir)
|
342 |
-
subset_metadata["image_dir"] = image_dir
|
343 |
-
image_dir_or_metadata_file = image_dir
|
344 |
-
|
345 |
-
if is_dreambooth_dataset:
|
346 |
-
subset_metadata["class_tokens"] = subset.class_tokens
|
347 |
-
subset_metadata["is_reg"] = subset.is_reg
|
348 |
-
if subset.is_reg:
|
349 |
-
image_dir_or_metadata_file = None # not merging reg dataset
|
350 |
-
else:
|
351 |
-
metadata_file = os.path.basename(subset.metadata_file)
|
352 |
-
subset_metadata["metadata_file"] = metadata_file
|
353 |
-
image_dir_or_metadata_file = metadata_file # may overwrite
|
354 |
-
|
355 |
-
subsets_metadata.append(subset_metadata)
|
356 |
-
|
357 |
-
# merge dataset dir: not reg subset only
|
358 |
-
# TODO update additional-network extension to show detailed dataset config from metadata
|
359 |
-
if image_dir_or_metadata_file is not None:
|
360 |
-
# datasets may have a certain dir multiple times
|
361 |
-
v = image_dir_or_metadata_file
|
362 |
-
i = 2
|
363 |
-
while v in dataset_dirs_info:
|
364 |
-
v = image_dir_or_metadata_file + f" ({i})"
|
365 |
-
i += 1
|
366 |
-
image_dir_or_metadata_file = v
|
367 |
-
|
368 |
-
dataset_dirs_info[image_dir_or_metadata_file] = {
|
369 |
-
"n_repeats": subset.num_repeats,
|
370 |
-
"img_count": subset.img_count
|
371 |
-
}
|
372 |
-
|
373 |
-
dataset_metadata["subsets"] = subsets_metadata
|
374 |
-
datasets_metadata.append(dataset_metadata)
|
375 |
-
|
376 |
-
# merge tag frequency:
|
377 |
-
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
378 |
-
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
379 |
-
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
380 |
-
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
381 |
-
if ds_dir_name in tag_frequency:
|
382 |
-
continue
|
383 |
-
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
384 |
-
|
385 |
-
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
386 |
-
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
387 |
-
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
388 |
-
else:
|
389 |
-
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
390 |
-
assert len(
|
391 |
-
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
392 |
-
|
393 |
-
dataset = train_dataset_group.datasets[0]
|
394 |
-
|
395 |
-
dataset_dirs_info = {}
|
396 |
-
reg_dataset_dirs_info = {}
|
397 |
-
if use_dreambooth_method:
|
398 |
-
for subset in dataset.subsets:
|
399 |
-
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
400 |
-
info[os.path.basename(subset.image_dir)] = {
|
401 |
-
"n_repeats": subset.num_repeats,
|
402 |
-
"img_count": subset.img_count
|
403 |
-
}
|
404 |
-
else:
|
405 |
-
for subset in dataset.subsets:
|
406 |
-
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
407 |
-
"n_repeats": subset.num_repeats,
|
408 |
-
"img_count": subset.img_count
|
409 |
-
}
|
410 |
-
|
411 |
-
metadata.update({
|
412 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
413 |
-
"ss_total_batch_size": total_batch_size,
|
414 |
-
"ss_resolution": args.resolution,
|
415 |
-
"ss_color_aug": bool(args.color_aug),
|
416 |
-
"ss_flip_aug": bool(args.flip_aug),
|
417 |
-
"ss_random_crop": bool(args.random_crop),
|
418 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
419 |
-
"ss_enable_bucket": bool(dataset.enable_bucket),
|
420 |
-
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
421 |
-
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
422 |
-
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
423 |
-
"ss_keep_tokens": args.keep_tokens,
|
424 |
-
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
425 |
-
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
426 |
-
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
427 |
-
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
428 |
-
})
|
429 |
-
|
430 |
# uncomment if another network is added
|
431 |
# for key, value in net_kwargs.items():
|
432 |
# metadata["ss_arg_" + key] = value
|
@@ -462,7 +410,7 @@ def train(args):
|
|
462 |
loss_total = 0.0
|
463 |
for epoch in range(num_train_epochs):
|
464 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
465 |
-
|
466 |
|
467 |
metadata["ss_epoch"] = str(epoch+1)
|
468 |
|
@@ -499,7 +447,7 @@ def train(args):
|
|
499 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
500 |
|
501 |
# Predict the noise residual
|
502 |
-
with
|
503 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
504 |
|
505 |
if args.v_parameterization:
|
@@ -517,9 +465,9 @@ def train(args):
|
|
517 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
518 |
|
519 |
accelerator.backward(loss)
|
520 |
-
if accelerator.sync_gradients
|
521 |
params_to_clip = network.get_trainable_params()
|
522 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
523 |
|
524 |
optimizer.step()
|
525 |
lr_scheduler.step()
|
@@ -530,8 +478,6 @@ def train(args):
|
|
530 |
progress_bar.update(1)
|
531 |
global_step += 1
|
532 |
|
533 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
534 |
-
|
535 |
current_loss = loss.detach().item()
|
536 |
if epoch == 0:
|
537 |
loss_list.append(current_loss)
|
@@ -562,7 +508,6 @@ def train(args):
|
|
562 |
def save_func():
|
563 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
564 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
565 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
566 |
print(f"saving checkpoint: {ckpt_file}")
|
567 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
568 |
|
@@ -577,12 +522,9 @@ def train(args):
|
|
577 |
if saving and args.save_state:
|
578 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
579 |
|
580 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
581 |
-
|
582 |
# end of epoch
|
583 |
|
584 |
metadata["ss_epoch"] = str(num_train_epochs)
|
585 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
586 |
|
587 |
is_main_process = accelerator.is_main_process
|
588 |
if is_main_process:
|
@@ -613,8 +555,6 @@ if __name__ == '__main__':
|
|
613 |
train_util.add_sd_models_arguments(parser)
|
614 |
train_util.add_dataset_arguments(parser, True, True, True)
|
615 |
train_util.add_training_arguments(parser, True)
|
616 |
-
train_util.add_optimizer_arguments(parser)
|
617 |
-
config_util.add_config_arguments(parser)
|
618 |
|
619 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
620 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -622,6 +562,10 @@ if __name__ == '__main__':
|
|
622 |
|
623 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
624 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
625 |
|
626 |
parser.add_argument("--network_weights", type=str, default=None,
|
627 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
1 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
+
from torch.optim import Optimizer
|
3 |
+
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
+
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
+
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
|
21 |
import library.train_util as train_util
|
22 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
26 |
return examples[0]
|
27 |
|
28 |
|
|
|
29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
31 |
|
32 |
if args.network_train_unet_only:
|
33 |
+
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
34 |
elif args.network_train_text_encoder_only:
|
35 |
+
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
36 |
else:
|
37 |
+
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
38 |
+
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
|
|
|
|
|
|
39 |
|
40 |
return logs
|
41 |
|
42 |
|
43 |
+
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
44 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
45 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
46 |
+
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
47 |
+
|
48 |
+
|
49 |
+
def get_scheduler_fix(
|
50 |
+
name: Union[str, SchedulerType],
|
51 |
+
optimizer: Optimizer,
|
52 |
+
num_warmup_steps: Optional[int] = None,
|
53 |
+
num_training_steps: Optional[int] = None,
|
54 |
+
num_cycles: int = 1,
|
55 |
+
power: float = 1.0,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Unified API to get any scheduler from its name.
|
59 |
+
Args:
|
60 |
+
name (`str` or `SchedulerType`):
|
61 |
+
The name of the scheduler to use.
|
62 |
+
optimizer (`torch.optim.Optimizer`):
|
63 |
+
The optimizer that will be used during training.
|
64 |
+
num_warmup_steps (`int`, *optional*):
|
65 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
66 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
67 |
+
num_training_steps (`int``, *optional*):
|
68 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
69 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
70 |
+
num_cycles (`int`, *optional*):
|
71 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
72 |
+
power (`float`, *optional*, defaults to 1.0):
|
73 |
+
Power factor. See `POLYNOMIAL` scheduler
|
74 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
75 |
+
The index of the last epoch when resuming training.
|
76 |
+
"""
|
77 |
+
name = SchedulerType(name)
|
78 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
79 |
+
if name == SchedulerType.CONSTANT:
|
80 |
+
return schedule_func(optimizer)
|
81 |
+
|
82 |
+
# All other schedulers require `num_warmup_steps`
|
83 |
+
if num_warmup_steps is None:
|
84 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
85 |
+
|
86 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
87 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
88 |
+
|
89 |
+
# All other schedulers require `num_training_steps`
|
90 |
+
if num_training_steps is None:
|
91 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
92 |
+
|
93 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
94 |
+
return schedule_func(
|
95 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
96 |
+
)
|
97 |
+
|
98 |
+
if name == SchedulerType.POLYNOMIAL:
|
99 |
+
return schedule_func(
|
100 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
101 |
+
)
|
102 |
+
|
103 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
104 |
+
|
105 |
+
|
106 |
def train(args):
|
107 |
session_id = random.randint(0, 2**32)
|
108 |
training_started_at = time.time()
|
|
|
111 |
|
112 |
cache_latents = args.cache_latents
|
113 |
use_dreambooth_method = args.in_json is None
|
|
|
114 |
|
115 |
if args.seed is not None:
|
116 |
set_seed(args.seed)
|
|
|
118 |
tokenizer = train_util.load_tokenizer(args)
|
119 |
|
120 |
# データセットを準備する
|
121 |
+
if use_dreambooth_method:
|
122 |
+
print("Use DreamBooth method.")
|
123 |
+
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
124 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
125 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
126 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
127 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
128 |
+
args.random_crop, args.debug_dataset)
|
129 |
else:
|
130 |
+
print("Train with captions.")
|
131 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
132 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
133 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
134 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
135 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
136 |
+
args.dataset_repeats, args.debug_dataset)
|
137 |
+
|
138 |
+
# 学習データのdropout率を設定する
|
139 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
140 |
+
|
141 |
+
train_dataset.make_buckets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if args.debug_dataset:
|
144 |
+
train_util.debug_dataset(train_dataset)
|
145 |
return
|
146 |
+
if len(train_dataset) == 0:
|
147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# acceleratorを準備する
|
151 |
print("prepare accelerator")
|
152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
161 |
if args.lowram:
|
162 |
text_encoder.to("cuda")
|
163 |
unet.to("cuda")
|
164 |
+
|
165 |
# モデルに xformers とか memory efficient attention を組み込む
|
166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
167 |
|
|
|
171 |
vae.requires_grad_(False)
|
172 |
vae.eval()
|
173 |
with torch.no_grad():
|
174 |
+
train_dataset.cache_latents(vae)
|
175 |
vae.to("cpu")
|
176 |
if torch.cuda.is_available():
|
177 |
torch.cuda.empty_cache()
|
|
|
208 |
# 学習に必要なクラスを準備する
|
209 |
print("prepare optimizer, data loader etc.")
|
210 |
|
211 |
+  # 8-bit Adamを使う
+  if args.use_8bit_adam:
+    try:
+      import bitsandbytes as bnb
+    except ImportError:
+      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
+    print("use 8-bit Adam optimizer")
+    optimizer_class = bnb.optim.AdamW8bit
+  elif args.use_lion_optimizer:
+    try:
+      import lion_pytorch
+    except ImportError:
+      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+    print("use Lion optimizer")
+    optimizer_class = lion_pytorch.Lion
+  else:
+    optimizer_class = torch.optim.AdamW
+
+  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+
+  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
+  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235 |
|
236 |
# dataloaderを準備する
|
237 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
238 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
239 |
train_dataloader = torch.utils.data.DataLoader(
|
240 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
241 |
|
242 |
# 学習ステップ数を計算する
|
243 |
if args.max_train_epochs is not None:
|
|
|
245 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
246 |
|
247 |
  # lr schedulerを用意する
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  lr_scheduler = get_scheduler_fix(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
253 |
|
254 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
255 |
if args.full_fp16:
|
|
|
317 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
318 |
|
319 |
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
+  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
+  print(f"  num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"  num epochs / epoch数: {num_train_epochs}")
+  print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+  print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
  print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
  print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
330 |
|
|
|
331 |
metadata = {
|
332 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
333 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
335 |
"ss_learning_rate": args.learning_rate,
|
336 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
337 |
"ss_unet_lr": args.unet_lr,
|
338 |
+
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
339 |
+
"ss_num_reg_images": train_dataset.num_reg_images,
|
340 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
341 |
"ss_num_epochs": num_train_epochs,
|
342 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
343 |
+
"ss_total_batch_size": total_batch_size,
|
344 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
345 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
346 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
352 |
"ss_mixed_precision": args.mixed_precision,
|
353 |
"ss_full_fp16": bool(args.full_fp16),
|
354 |
"ss_v2": bool(args.v2),
|
355 |
+
"ss_resolution": args.resolution,
|
356 |
"ss_clip_skip": args.clip_skip,
|
357 |
"ss_max_token_length": args.max_token_length,
|
358 |
+
"ss_color_aug": bool(args.color_aug),
|
359 |
+
"ss_flip_aug": bool(args.flip_aug),
|
360 |
+
"ss_random_crop": bool(args.random_crop),
|
361 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
362 |
"ss_cache_latents": bool(args.cache_latents),
|
363 |
+
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
364 |
+
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
365 |
+
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
366 |
"ss_seed": args.seed,
|
367 |
+
"ss_keep_tokens": args.keep_tokens,
|
368 |
"ss_noise_offset": args.noise_offset,
|
369 |
+
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
370 |
+
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
371 |
+
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
372 |
+
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
373 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
374 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
375 |
+
"ss_optimizer": optimizer_name
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
}
|
377 |
|
|
|
378 |
# uncomment if another network is added
|
379 |
# for key, value in net_kwargs.items():
|
380 |
# metadata["ss_arg_" + key] = value
|
|
|
410 |
loss_total = 0.0
|
411 |
for epoch in range(num_train_epochs):
|
412 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
413 |
+
train_dataset.set_current_epoch(epoch + 1)
|
414 |
|
415 |
metadata["ss_epoch"] = str(epoch+1)
|
416 |
|
|
|
447 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
448 |
|
449 |
# Predict the noise residual
|
450 |
+
with autocast():
|
451 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
452 |
|
453 |
if args.v_parameterization:
|
|
|
465 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
466 |
|
467 |
accelerator.backward(loss)
|
468 |
+
if accelerator.sync_gradients:
|
469 |
params_to_clip = network.get_trainable_params()
|
470 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
471 |
|
472 |
optimizer.step()
|
473 |
lr_scheduler.step()
|
|
|
478 |
progress_bar.update(1)
|
479 |
global_step += 1
|
480 |
|
|
|
|
|
481 |
current_loss = loss.detach().item()
|
482 |
if epoch == 0:
|
483 |
loss_list.append(current_loss)
|
|
|
508 |
def save_func():
|
509 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
510 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
511 |
print(f"saving checkpoint: {ckpt_file}")
|
512 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
513 |
|
|
|
522 |
if saving and args.save_state:
|
523 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
524 |
|
|
|
|
|
525 |
# end of epoch
|
526 |
|
527 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
528 |
|
529 |
is_main_process = accelerator.is_main_process
|
530 |
if is_main_process:
|
|
|
555 |
train_util.add_sd_models_arguments(parser)
|
556 |
train_util.add_dataset_arguments(parser, True, True, True)
|
557 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
558 |
|
559 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
560 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
562 |
|
563 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
564 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
565 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
566 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
567 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
568 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
569 |
|
570 |
parser.add_argument("--network_weights", type=str, default=None,
|
571 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
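The two options added above feed num_cycles and power into get_scheduler_fix. As a rough illustration of what the power value does (a sketch of the polynomial decay shape with a zero final learning rate and no warmup, not the exact diffusers formula):

    # illustrative only: how a larger power makes the decay front-loaded
    def polynomial_lr(base_lr, step, total_steps, power=1.0):
        progress = min(step, total_steps) / total_steps
        return base_lr * (1.0 - progress) ** power

    print(polynomial_lr(1e-4, 800, 1600, power=2.0))   # 2.5e-05 at the halfway point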
train_network_opt.py
CHANGED
@@ -1,5 +1,8 @@
|
|
|
|
|
|
1 |
from torch.cuda.amp import autocast
|
2 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
3 |
import importlib
|
4 |
import argparse
|
5 |
import gc
|
@@ -12,49 +15,138 @@ import json
|
|
12 |
from tqdm import tqdm
|
13 |
import torch
|
14 |
from accelerate.utils import set_seed
|
15 |
-
|
16 |
from diffusers import DDPMScheduler
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
##### バケット拡張のためのモジュール
|
19 |
import append_module
|
20 |
######
|
21 |
import library.train_util as train_util
|
22 |
-
from library.train_util import
|
23 |
-
DreamBoothDataset,
|
24 |
-
)
|
25 |
-
import library.config_util as config_util
|
26 |
-
from library.config_util import (
|
27 |
-
ConfigSanitizer,
|
28 |
-
BlueprintGenerator,
|
29 |
-
)
|
30 |
|
31 |
|
32 |
def collate_fn(examples):
|
33 |
return examples[0]
|
34 |
|
35 |
|
36 |
-
|
37 |
-
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
|
38 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
else:
|
45 |
-
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
46 |
-
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
47 |
else:
|
48 |
last_lrs = lr_scheduler.get_last_lr()
|
49 |
-
|
50 |
-
logs[
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
return logs
|
56 |
|
57 |
|
|
|
58 |
def train(args):
|
59 |
session_id = random.randint(0, 2**32)
|
60 |
training_started_at = time.time()
|
@@ -63,7 +155,6 @@ def train(args):
|
|
63 |
|
64 |
cache_latents = args.cache_latents
|
65 |
use_dreambooth_method = args.in_json is None
|
66 |
-
use_user_config = args.dataset_config is not None
|
67 |
|
68 |
if args.seed is not None:
|
69 |
set_seed(args.seed)
|
@@ -71,72 +162,52 @@ def train(args):
|
|
71 |
tokenizer = train_util.load_tokenizer(args)
|
72 |
|
73 |
# データセットを準備する
|
74 |
-
if
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
88 |
-
else:
|
89 |
-
if use_dreambooth_method:
|
90 |
-
print("Use DreamBooth method.")
|
91 |
-
user_config = {
|
92 |
-
"datasets": [{
|
93 |
-
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
94 |
-
}]
|
95 |
-
}
|
96 |
-
else:
|
97 |
-
print("Train with captions.")
|
98 |
-
user_config = {
|
99 |
-
"datasets": [{
|
100 |
-
"subsets": [{
|
101 |
-
"image_dir": args.train_data_dir,
|
102 |
-
"metadata_file": args.in_json,
|
103 |
-
}]
|
104 |
-
}]
|
105 |
-
}
|
106 |
-
|
107 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
108 |
-
if args.min_resolution:
|
109 |
-
train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
110 |
else:
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
if args.debug_dataset:
|
114 |
-
train_util.debug_dataset(
|
115 |
return
|
116 |
-
if len(
|
117 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
118 |
return
|
119 |
|
120 |
-
if cache_latents:
|
121 |
-
assert train_dataset_group.is_latent_cacheable(
|
122 |
-
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
123 |
-
|
124 |
# acceleratorを準備する
|
125 |
print("prepare accelerator")
|
126 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
127 |
-
is_main_process = accelerator.is_main_process
|
128 |
|
129 |
# mixed precisionに対応した型を用意しておき適宜castする
|
130 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
131 |
|
132 |
# モデルを読み込む
|
133 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
text_encoder.to("cuda")
|
138 |
-
unet.to("cuda")
|
139 |
-
|
140 |
# モデルに xformers とか memory efficient attention を組み込む
|
141 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
142 |
|
@@ -146,15 +217,13 @@ def train(args):
|
|
146 |
vae.requires_grad_(False)
|
147 |
vae.eval()
|
148 |
with torch.no_grad():
|
149 |
-
|
150 |
vae.to("cpu")
|
151 |
if torch.cuda.is_available():
|
152 |
torch.cuda.empty_cache()
|
153 |
gc.collect()
|
154 |
|
155 |
# prepare network
|
156 |
-
import sys
|
157 |
-
sys.path.append(os.path.dirname(__file__))
|
158 |
print("import network module:", args.network_module)
|
159 |
network_module = importlib.import_module(args.network_module)
|
160 |
|
@@ -184,65 +253,188 @@ def train(args):
|
|
184 |
|
185 |
# 学習に必要なクラスを準備する
|
186 |
print("prepare optimizer, data loader etc.")
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
190 |
if args.split_lora_networks:
|
191 |
-
lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
|
192 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
193 |
-
append_module.replace_prepare_optimizer_params(network
|
194 |
-
trainable_params,
|
195 |
else:
|
196 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
if args.
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
if args.lookahead_arg is not None:
|
212 |
-
for _arg in args.lookahead_arg:
|
213 |
-
k, v = _arg.split("=")
|
214 |
-
if k == "k":
|
215 |
-
lookahed_arg[k] = int(v)
|
216 |
-
else:
|
217 |
-
lookahed_arg[k] = float(v)
|
218 |
-
optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
|
219 |
-
except:
|
220 |
-
print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
|
221 |
# dataloaderを準備する
|
222 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
223 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
224 |
train_dataloader = torch.utils.data.DataLoader(
|
225 |
-
|
226 |
|
227 |
# 学習ステップ数を計算する
|
228 |
if args.max_train_epochs is not None:
|
229 |
-
args.max_train_steps = args.max_train_epochs *
|
230 |
-
|
231 |
-
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
232 |
|
233 |
# lr schedulerを用意する
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
236 |
else:
|
237 |
-
lr_scheduler =
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
241 |
#追加機能の設定をコメントに追記して残す
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
|
246 |
if args.min_resolution:
|
247 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
248 |
|
@@ -312,21 +504,17 @@ def train(args):
|
|
312 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
313 |
|
314 |
# 学習する
|
315 |
-
# TODO: find a way to handle total batch size when there are multiple datasets
|
316 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
328 |
-
|
329 |
-
# TODO refactor metadata creation and move to util
|
330 |
metadata = {
|
331 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
332 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -334,10 +522,12 @@ def train(args):
|
|
334 |
"ss_learning_rate": args.learning_rate,
|
335 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
336 |
"ss_unet_lr": args.unet_lr,
|
337 |
-
"ss_num_train_images":
|
338 |
-
"ss_num_reg_images":
|
339 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
340 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
341 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
342 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
343 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -349,156 +539,32 @@ def train(args):
|
|
349 |
"ss_mixed_precision": args.mixed_precision,
|
350 |
"ss_full_fp16": bool(args.full_fp16),
|
351 |
"ss_v2": bool(args.v2),
|
|
|
352 |
"ss_clip_skip": args.clip_skip,
|
353 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
354 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
355 |
"ss_seed": args.seed,
|
356 |
-
"
|
357 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
358 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
359 |
-
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
360 |
-
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
361 |
-
"ss_max_grad_norm": args.max_grad_norm,
|
362 |
-
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
363 |
-
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
364 |
-
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
365 |
-
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
366 |
-
"ss_prior_loss_weight": args.prior_loss_weight,
|
367 |
}
|
368 |
|
369 |
-
if
|
370 |
-
# save metadata of multiple datasets
|
371 |
-
# NOTE: pack "ss_datasets" value as json one time
|
372 |
-
# or should also pack nested collections as json?
|
373 |
-
datasets_metadata = []
|
374 |
-
tag_frequency = {} # merge tag frequency for metadata editor
|
375 |
-
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
376 |
-
|
377 |
-
for dataset in train_dataset_group.datasets:
|
378 |
-
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
379 |
-
dataset_metadata = {
|
380 |
-
"is_dreambooth": is_dreambooth_dataset,
|
381 |
-
"batch_size_per_device": dataset.batch_size,
|
382 |
-
"num_train_images": dataset.num_train_images, # includes repeating
|
383 |
-
"num_reg_images": dataset.num_reg_images,
|
384 |
-
"resolution": (dataset.width, dataset.height),
|
385 |
-
"enable_bucket": bool(dataset.enable_bucket),
|
386 |
-
"min_bucket_reso": dataset.min_bucket_reso,
|
387 |
-
"max_bucket_reso": dataset.max_bucket_reso,
|
388 |
-
"tag_frequency": dataset.tag_frequency,
|
389 |
-
"bucket_info": dataset.bucket_info,
|
390 |
-
}
|
391 |
-
|
392 |
-
subsets_metadata = []
|
393 |
-
for subset in dataset.subsets:
|
394 |
-
subset_metadata = {
|
395 |
-
"img_count": subset.img_count,
|
396 |
-
"num_repeats": subset.num_repeats,
|
397 |
-
"color_aug": bool(subset.color_aug),
|
398 |
-
"flip_aug": bool(subset.flip_aug),
|
399 |
-
"random_crop": bool(subset.random_crop),
|
400 |
-
"shuffle_caption": bool(subset.shuffle_caption),
|
401 |
-
"keep_tokens": subset.keep_tokens,
|
402 |
-
}
|
403 |
-
|
404 |
-
image_dir_or_metadata_file = None
|
405 |
-
if subset.image_dir:
|
406 |
-
image_dir = os.path.basename(subset.image_dir)
|
407 |
-
subset_metadata["image_dir"] = image_dir
|
408 |
-
image_dir_or_metadata_file = image_dir
|
409 |
-
|
410 |
-
if is_dreambooth_dataset:
|
411 |
-
subset_metadata["class_tokens"] = subset.class_tokens
|
412 |
-
subset_metadata["is_reg"] = subset.is_reg
|
413 |
-
if subset.is_reg:
|
414 |
-
image_dir_or_metadata_file = None # not merging reg dataset
|
415 |
-
else:
|
416 |
-
metadata_file = os.path.basename(subset.metadata_file)
|
417 |
-
subset_metadata["metadata_file"] = metadata_file
|
418 |
-
image_dir_or_metadata_file = metadata_file # may overwrite
|
419 |
-
|
420 |
-
subsets_metadata.append(subset_metadata)
|
421 |
-
|
422 |
-
# merge dataset dir: not reg subset only
|
423 |
-
# TODO update additional-network extension to show detailed dataset config from metadata
|
424 |
-
if image_dir_or_metadata_file is not None:
|
425 |
-
# datasets may have a certain dir multiple times
|
426 |
-
v = image_dir_or_metadata_file
|
427 |
-
i = 2
|
428 |
-
while v in dataset_dirs_info:
|
429 |
-
v = image_dir_or_metadata_file + f" ({i})"
|
430 |
-
i += 1
|
431 |
-
image_dir_or_metadata_file = v
|
432 |
-
|
433 |
-
dataset_dirs_info[image_dir_or_metadata_file] = {
|
434 |
-
"n_repeats": subset.num_repeats,
|
435 |
-
"img_count": subset.img_count
|
436 |
-
}
|
437 |
-
|
438 |
-
dataset_metadata["subsets"] = subsets_metadata
|
439 |
-
datasets_metadata.append(dataset_metadata)
|
440 |
-
|
441 |
-
# merge tag frequency:
|
442 |
-
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
443 |
-
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
444 |
-
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
445 |
-
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
446 |
-
if ds_dir_name in tag_frequency:
|
447 |
-
continue
|
448 |
-
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
449 |
-
|
450 |
-
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
451 |
-
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
452 |
-
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
453 |
-
else:
|
454 |
-
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
455 |
-
assert len(
|
456 |
-
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
457 |
-
|
458 |
-
dataset = train_dataset_group.datasets[0]
|
459 |
-
|
460 |
-
dataset_dirs_info = {}
|
461 |
-
reg_dataset_dirs_info = {}
|
462 |
-
if use_dreambooth_method:
|
463 |
-
for subset in dataset.subsets:
|
464 |
-
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
465 |
-
info[os.path.basename(subset.image_dir)] = {
|
466 |
-
"n_repeats": subset.num_repeats,
|
467 |
-
"img_count": subset.img_count
|
468 |
-
}
|
469 |
-
else:
|
470 |
-
for subset in dataset.subsets:
|
471 |
-
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
472 |
-
"n_repeats": subset.num_repeats,
|
473 |
-
"img_count": subset.img_count
|
474 |
-
}
|
475 |
-
|
476 |
-
metadata.update({
|
477 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
478 |
-
"ss_total_batch_size": total_batch_size,
|
479 |
-
"ss_resolution": args.resolution,
|
480 |
-
"ss_color_aug": bool(args.color_aug),
|
481 |
-
"ss_flip_aug": bool(args.flip_aug),
|
482 |
-
"ss_random_crop": bool(args.random_crop),
|
483 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
484 |
-
"ss_enable_bucket": bool(dataset.enable_bucket),
|
485 |
-
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
486 |
-
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
487 |
-
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
488 |
-
"ss_keep_tokens": args.keep_tokens,
|
489 |
-
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
490 |
-
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
491 |
-
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
492 |
-
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
493 |
-
})
|
494 |
-
|
495 |
-
# add extra args
|
496 |
-
if args.network_args:
|
497 |
-
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
498 |
# for key, value in net_kwargs.items():
|
499 |
# metadata["ss_arg_" + key] = value
|
500 |
|
501 |
-
# model name and hash
|
502 |
if args.pretrained_model_name_or_path is not None:
|
503 |
sd_model_name = args.pretrained_model_name_or_path
|
504 |
if os.path.exists(sd_model_name):
|
@@ -517,13 +583,6 @@ def train(args):
|
|
517 |
|
518 |
metadata = {k: str(v) for k, v in metadata.items()}
|
519 |
|
520 |
-
# make minimum metadata for filtering
|
521 |
-
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
522 |
-
minimum_metadata = {}
|
523 |
-
for key in minimum_keys:
|
524 |
-
if key in metadata:
|
525 |
-
minimum_metadata[key] = metadata[key]
|
526 |
-
|
527 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
528 |
global_step = 0
|
529 |
|
@@ -536,9 +595,8 @@ def train(args):
|
|
536 |
loss_list = []
|
537 |
loss_total = 0.0
|
538 |
for epoch in range(num_train_epochs):
|
539 |
-
|
540 |
-
|
541 |
-
train_dataset_group.set_current_epoch(epoch + 1)
|
542 |
|
543 |
metadata["ss_epoch"] = str(epoch+1)
|
544 |
|
@@ -575,7 +633,7 @@ def train(args):
|
|
575 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
576 |
|
577 |
# Predict the noise residual
|
578 |
-
with
|
579 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
580 |
|
581 |
if args.v_parameterization:
|
@@ -593,13 +651,12 @@ def train(args):
|
|
593 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
594 |
|
595 |
accelerator.backward(loss)
|
596 |
-
if accelerator.sync_gradients
|
597 |
params_to_clip = network.get_trainable_params()
|
598 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
599 |
|
600 |
optimizer.step()
|
601 |
-
|
602 |
-
lr_scheduler.step()
|
603 |
optimizer.zero_grad(set_to_none=True)
|
604 |
|
605 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
@@ -607,8 +664,6 @@ def train(args):
|
|
607 |
progress_bar.update(1)
|
608 |
global_step += 1
|
609 |
|
610 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
611 |
-
|
612 |
current_loss = loss.detach().item()
|
613 |
if epoch == 0:
|
614 |
loss_list.append(current_loss)
|
@@ -621,7 +676,7 @@ def train(args):
|
|
621 |
progress_bar.set_postfix(**logs)
|
622 |
|
623 |
if args.logging_dir is not None:
|
624 |
-
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler
|
625 |
accelerator.log(logs, step=global_step)
|
626 |
|
627 |
if global_step >= args.max_train_steps:
|
@@ -639,9 +694,8 @@ def train(args):
|
|
639 |
def save_func():
|
640 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
641 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
642 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
643 |
print(f"saving checkpoint: {ckpt_file}")
|
644 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
645 |
|
646 |
def remove_old_func(old_epoch_no):
|
647 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -650,18 +704,15 @@ def train(args):
|
|
650 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
651 |
os.remove(old_ckpt_file)
|
652 |
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
657 |
-
|
658 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
659 |
|
660 |
# end of epoch
|
661 |
|
662 |
metadata["ss_epoch"] = str(num_train_epochs)
|
663 |
-
metadata["ss_training_finished_at"] = str(time.time())
|
664 |
|
|
|
665 |
if is_main_process:
|
666 |
network = unwrap_model(network)
|
667 |
|
@@ -680,7 +731,7 @@ def train(args):
|
|
680 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
681 |
|
682 |
print(f"save trained model to {ckpt_file}")
|
683 |
-
network.save_weights(ckpt_file, save_dtype,
|
684 |
print("model saved.")
|
685 |
|
686 |
|
@@ -690,8 +741,6 @@ if __name__ == '__main__':
|
|
690 |
train_util.add_sd_models_arguments(parser)
|
691 |
train_util.add_dataset_arguments(parser, True, True, True)
|
692 |
train_util.add_training_arguments(parser, True)
|
693 |
-
train_util.add_optimizer_arguments(parser)
|
694 |
-
config_util.add_config_arguments(parser)
|
695 |
|
696 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
697 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -699,6 +748,10 @@ if __name__ == '__main__':
|
|
699 |
|
700 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
701 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
702 |
|
703 |
parser.add_argument("--network_weights", type=str, default=None,
|
704 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
@@ -718,30 +771,27 @@ if __name__ == '__main__':
|
|
718 |
#Optimizer変更関連のオプション追加
|
719 |
append_module.add_append_arguments(parser)
|
720 |
args = append_module.get_config(parser)
|
721 |
-
if not args.not_output_config:
|
722 |
-
#argsを保存する
|
723 |
-
import yaml
|
724 |
-
import datetime
|
725 |
-
_t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
726 |
-
if args.output_name==None:
|
727 |
-
config_name = f"train_network_config_{_t}.yaml"
|
728 |
-
else:
|
729 |
-
config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
730 |
-
print(f"{config_name} に設定を書き出し中...")
|
731 |
-
with open(config_name, mode="w") as f:
|
732 |
-
yaml.dump(args.__dict__, f, indent=4)
|
733 |
|
734 |
if args.resolution==args.min_resolution:
|
735 |
args.min_resolution=None
|
736 |
|
737 |
train(args)
|
738 |
-
print("done!")
|
739 |
|
|
740 |
|
741 |
'''
|
742 |
optimizer設定メモ
|
743 |
-
torch_optimizer.AdaBelief
|
744 |
-
adastand.Adastand
|
745 |
(optimizer_argから設定できるように変更するためのメモ)
|
746 |
|
747 |
AdamWのweight_decay初期値は1e-2
|
@@ -771,7 +821,6 @@ Adafactor
|
|
771 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
772 |
huggingfaceのサンプルパラ
|
773 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
774 |
-
epsの二つ目の値1e-3が学習率に影響大きい
|
775 |
|
776 |
AggMo
|
777 |
|
|
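The memo above quotes Hugging Face's sample Adafactor parameters. A rough sketch of how those values would be passed, assuming the transformers implementation and a placeholder parameter group (the lr value here is illustrative):

    from transformers.optimization import Adafactor
    import torch

    params = torch.nn.Linear(4, 4).parameters()     # placeholder parameters, not from the scripts
    optimizer = Adafactor(
        params, lr=1e-3,
        eps=(1e-30, 1e-3),          # the second eps value strongly affects the effective lr, as the memo notes
        clip_threshold=1.0, decay_rate=-0.8,
        relative_step=False, scale_parameter=False, warmup_init=False)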
|
1 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
+
from torch.optim import Optimizer
|
3 |
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
+
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
+
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
+
print("**********************************")
|
21 |
+
#先に
|
22 |
+
#pip install torch_optimizer
|
23 |
+
#が必要
|
24 |
+
try:
|
25 |
+
import torch_optimizer as optim
|
26 |
+
except:
|
27 |
+
print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
|
28 |
+
try:
|
29 |
+
import adastand
|
30 |
+
except:
|
31 |
+
print("※Adastandが使えません")
|
32 |
+
|
33 |
+
from transformers.optimization import Adafactor, AdafactorSchedule
|
34 |
+
print("**********************************")
|
35 |
##### バケット拡張のためのモジュール
|
36 |
import append_module
|
37 |
######
|
38 |
import library.train_util as train_util
|
39 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def collate_fn(examples):
|
43 |
return examples[0]
|
44 |
|
45 |
|
46 |
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+  if args.network_train_unet_only:
+    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+  elif args.network_train_text_encoder_only:
+    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
  else:
    last_lrs = lr_scheduler.get_last_lr()
+    if len(last_lrs) == 2:
+      logs["lr/textencoder"] = float(last_lrs[0])
+      logs["lr/unet"] = float(last_lrs[-1])  # may be same to textencoder
+    else:
+      if len(last_lrs) == 4:
+        logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
+      elif len(last_lrs) == 8:
+        logs_names = ["textencoder", "unet_midblock"]
+        for i in range(3):
+          logs_names.append(f"unet_down_blocks_{i}")
+          logs_names.append(f"unet_up_blocks_{i+1}")
+      else:
+        logs_names = []
+        for i in range(12):
+          logs_names.append(f"text_model_encoder_layers_{i}_")
+        logs_names.append("unet_midblock")
+        for i in range(3):
+          logs_names.append(f"unet_down_blocks_{i}")
+          logs_names.append(f"unet_up_blocks_{i+1}")
+
+      for last_lr, logs_name in zip(last_lrs, logs_names):
+        logs[f"lr/{logs_name}"] = float(last_lr)

  return logs
|
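For reference, a sketch of the dict generate_step_logs builds when the scheduler exposes two parameter groups (text encoder and U-Net); the numbers are made up, only the key layout reflects the function above:

    # shape of the logged dict for a two-group scheduler; values are illustrative
    logs = {
        "loss/current": 0.081,
        "loss/average": 0.092,
        "lr/textencoder": 5e-5,
        "lr/unet": 1e-4,
    }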
79 |
|
80 |
|
81 |
+
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
|
82 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
83 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
84 |
+
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
85 |
+
|
86 |
+
|
87 |
+
def get_scheduler_fix(
|
88 |
+
name: Union[str, SchedulerType],
|
89 |
+
optimizer: Optimizer,
|
90 |
+
num_warmup_steps: Optional[int] = None,
|
91 |
+
num_training_steps: Optional[int] = None,
|
92 |
+
num_cycles: float = 1.,
|
93 |
+
power: float = 1.0,
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
Unified API to get any scheduler from its name.
|
97 |
+
Args:
|
98 |
+
name (`str` or `SchedulerType`):
|
99 |
+
The name of the scheduler to use.
|
100 |
+
optimizer (`torch.optim.Optimizer`):
|
101 |
+
The optimizer that will be used during training.
|
102 |
+
num_warmup_steps (`int`, *optional*):
|
103 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
104 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
105 |
+
num_training_steps (`int``, *optional*):
|
106 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
107 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
108 |
+
num_cycles (`int`, *optional*):
|
109 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
110 |
+
power (`float`, *optional*, defaults to 1.0):
|
111 |
+
Power factor. See `POLYNOMIAL` scheduler
|
112 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
113 |
+
The index of the last epoch when resuming training.
|
114 |
+
"""
|
115 |
+
name = SchedulerType(name)
|
116 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
117 |
+
if name == SchedulerType.CONSTANT:
|
118 |
+
return schedule_func(optimizer)
|
119 |
+
|
120 |
+
# All other schedulers require `num_warmup_steps`
|
121 |
+
if num_warmup_steps is None:
|
122 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
123 |
+
|
124 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
125 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
126 |
+
|
127 |
+
# All other schedulers require `num_training_steps`
|
128 |
+
if num_training_steps is None:
|
129 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
130 |
+
|
131 |
+
if name == SchedulerType.COSINE:
|
132 |
+
print(f"{name} num_cycles: {num_cycles}")
|
133 |
+
return schedule_func(
|
134 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
135 |
+
)
|
136 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
137 |
+
print(f"{name} num_cycles: {int(num_cycles)}")
|
138 |
+
return schedule_func(
|
139 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
|
140 |
+
)
|
141 |
+
|
142 |
+
if name == SchedulerType.POLYNOMIAL:
|
143 |
+
return schedule_func(
|
144 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
145 |
+
)
|
146 |
+
|
147 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
148 |
+
|
149 |
+
|
150 |
def train(args):
|
151 |
session_id = random.randint(0, 2**32)
|
152 |
training_started_at = time.time()
|
|
|
155 |
|
156 |
cache_latents = args.cache_latents
|
157 |
use_dreambooth_method = args.in_json is None
|
|
|
158 |
|
159 |
if args.seed is not None:
|
160 |
set_seed(args.seed)
|
|
|
162 |
tokenizer = train_util.load_tokenizer(args)
|
163 |
|
164 |
# データセットを準備する
|
165 |
+
if use_dreambooth_method:
|
166 |
+
if args.min_resolution:
|
167 |
+
args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
|
168 |
+
if len(args.min_resolution) == 1:
|
169 |
+
args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
|
170 |
+
|
171 |
+
print("Use DreamBooth method.")
|
172 |
+
train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
173 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
174 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
175 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
176 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
177 |
+
args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
else:
|
179 |
+
print("Train with captions.")
|
180 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
181 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
182 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
183 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
184 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
185 |
+
args.dataset_repeats, args.debug_dataset)
|
186 |
+
|
187 |
+
# 学習データのdropout率を設定する
|
188 |
+
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
189 |
+
|
190 |
+
train_dataset.make_buckets()
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
+
train_util.debug_dataset(train_dataset)
|
194 |
return
|
195 |
+
if len(train_dataset) == 0:
|
196 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
197 |
return
|
198 |
|
|
|
|
|
|
|
|
|
199 |
# acceleratorを準備する
|
200 |
print("prepare accelerator")
|
201 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
202 |
|
203 |
# mixed precisionに対応した型を用意しておき適宜castする
|
204 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
205 |
|
206 |
# モデルを読み込む
|
207 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
208 |
+
# unnecessary, but work on low-ram device
|
209 |
+
text_encoder.to("cuda")
|
210 |
+
unet.to("cuda")
|
|
|
|
|
|
|
211 |
# モデルに xformers とか memory efficient attention を組み込む
|
212 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
213 |
|
|
|
217 |
vae.requires_grad_(False)
|
218 |
vae.eval()
|
219 |
with torch.no_grad():
|
220 |
+
train_dataset.cache_latents(vae)
|
221 |
vae.to("cpu")
|
222 |
if torch.cuda.is_available():
|
223 |
torch.cuda.empty_cache()
|
224 |
gc.collect()
|
225 |
|
226 |
# prepare network
|
|
|
|
|
227 |
print("import network module:", args.network_module)
|
228 |
network_module = importlib.import_module(args.network_module)
|
229 |
|
|
|
253 |
|
254 |
# 学習に必要なクラスを準備する
|
255 |
print("prepare optimizer, data loader etc.")
|
256 |
+
try:
|
257 |
+
print(f"torch_optimzier version is {optim.__version__}")
|
258 |
+
not_torch_optimizer_flag = False
|
259 |
+
except:
|
260 |
+
not_torch_optimizer_flag = True
|
261 |
+
try:
|
262 |
+
print(f"adastand version is {adastand.__version__()}")
|
263 |
+
not_adasatand_optimzier_flag = False
|
264 |
+
except:
|
265 |
+
not_adasatand_optimzier_flag = True
|
266 |
+
|
267 |
+
# 8-bit Adamを使う
|
268 |
+
if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
|
269 |
+
not_torch_optimizer_flag = False
|
270 |
+
if args.optimizer=="Adafactor":
|
271 |
+
not_adasatand_optimzier_flag = False
|
272 |
+
if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
|
273 |
+
print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
|
274 |
+
args.optimizer="AdamW"
|
275 |
+
if args.use_8bit_adam:
|
276 |
+
if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
|
277 |
+
print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
|
278 |
+
args.use_8bit_adam=False
|
279 |
+
if args.use_8bit_adam:
|
280 |
+
try:
|
281 |
+
import bitsandbytes as bnb
|
282 |
+
except ImportError:
|
283 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
284 |
+
print("use 8-bit Adam optimizer")
|
285 |
+
args.training_comment=f"{args.training_comment} use_8bit_adam=True"
|
286 |
+
if args.optimizer=="Lamb":
|
287 |
+
optimizer_class = bnb.optim.LAMB8bit
|
288 |
+
else:
|
289 |
+
args.optimizer="AdamW"
|
290 |
+
optimizer_class = bnb.optim.AdamW8bit
|
291 |
+
else:
|
292 |
+
print(f"use {args.optimizer}")
|
293 |
+
if args.optimizer=="RAdam":
|
294 |
+
optimizer_class = torch.optim.RAdam
|
295 |
+
elif args.optimizer=="AdaBound":
|
296 |
+
optimizer_class = optim.AdaBound
|
297 |
+
elif args.optimizer=="AdaBelief":
|
298 |
+
optimizer_class = optim.AdaBelief
|
299 |
+
elif args.optimizer=="AdamP":
|
300 |
+
optimizer_class = optim.AdamP
|
301 |
+
elif args.optimizer=="Adafactor":
|
302 |
+
optimizer_class = Adafactor
|
303 |
+
elif args.optimizer=="Adastand":
|
304 |
+
optimizer_class = adastand.Adastand
|
305 |
+
elif args.optimizer=="Adastand_belief":
|
306 |
+
optimizer_class = adastand.Adastand_b
|
307 |
+
elif args.optimizer=="AggMo":
|
308 |
+
optimizer_class = optim.AggMo
|
309 |
+
elif args.optimizer=="Apollo":
|
310 |
+
optimizer_class = optim.Apollo
|
311 |
+
elif args.optimizer=="Lamb":
|
312 |
+
optimizer_class = optim.Lamb
|
313 |
+
elif args.optimizer=="Ranger":
|
314 |
+
optimizer_class = optim.Ranger
|
315 |
+
elif args.optimizer=="RangerVA":
|
316 |
+
optimizer_class = optim.RangerVA
|
317 |
+
elif args.optimizer=="Yogi":
|
318 |
+
optimizer_class = optim.Yogi
|
319 |
+
elif args.optimizer=="Shampoo":
|
320 |
+
optimizer_class = optim.Shampoo
|
321 |
+
elif args.optimizer=="NovoGrad":
|
322 |
+
optimizer_class = optim.NovoGrad
|
323 |
+
elif args.optimizer=="QHAdam":
|
324 |
+
optimizer_class = optim.QHAdam
|
325 |
+
elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
|
326 |
+
optimizer_class = optim.DiffGrad
|
327 |
+
elif args.optimizer=="MADGRAD":
|
328 |
+
optimizer_class = optim.MADGRAD
|
329 |
+
else:
|
330 |
+
optimizer_class = torch.optim.AdamW
|
331 |
+
|
332 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
333 |
+
#optimizerデフォ設定
|
334 |
+
if args.optimizer_arg==None:
|
335 |
+
if args.optimizer=="AdaBelief":
|
336 |
+
args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
|
337 |
+
elif args.optimizer=="DiffGrad":
|
338 |
+
args.optimizer_arg = ["eps=1e-16"]
|
339 |
+
optimizer_arg = {}
|
340 |
+
lookahed_arg = {"k": 5, "alpha": 0.5}
|
341 |
+
adafactor_scheduler_arg = {"initial_lr": 0.}
|
342 |
+
int_args = ["k","n_sma_threshold","warmup"]
|
343 |
+
str_args = ["transformer","grad_transformer"]
|
344 |
+
if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
|
345 |
+
for _opt_arg in args.optimizer_arg:
|
346 |
+
key, value = _opt_arg.split("=")
|
347 |
+
if value=="True" or value=="False":
|
348 |
+
optimizer_arg[key]=bool((value=="True"))
|
349 |
+
elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
|
350 |
+
_value = value.split(",")
|
351 |
+
optimizer_arg[key] = (float(_value[0]),float(_value[1]))
|
352 |
+
del _value
|
353 |
+
elif key in int_args:
|
354 |
+
if "Lookahead" in args.optimizer:
|
355 |
+
lookahed_arg[key] = int(value)
|
356 |
+
else:
|
357 |
+
optimizer_arg[key] = int(value)
|
358 |
+
elif key in str_args:
|
359 |
+
optimizer_arg[key] = value
|
360 |
+
else:
|
361 |
+
if key=="alpha" and "Lookahead" in args.optimizer:
|
362 |
+
lookahed_arg[key] = int(value)
|
363 |
+
elif key=="initial_lr" and args.optimizer == "Adafactor":
|
364 |
+
adafactor_scheduler_arg[key] = float(value)
|
365 |
+
else:
|
366 |
+
optimizer_arg[key] = float(value)
|
367 |
+
del _opt_arg
|
368 |
+
AdafactorScheduler_Flag = False
|
369 |
+
list_of_init_lr = []
|
370 |
+
if args.optimizer=="Adafactor":
|
371 |
+
if not "relative_step" in optimizer_arg:
|
372 |
+
optimizer_arg["relative_step"] = True
|
373 |
+
if "warmup_init" in optimizer_arg:
|
374 |
+
if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
|
375 |
+
print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
|
376 |
+
optimizer_arg["relative_step"] = True
|
377 |
+
if optimizer_arg["relative_step"] == True:
|
378 |
+
AdafactorScheduler_Flag = True
|
379 |
+
list_of_init_lr = [0.,0.]
|
380 |
+
if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
|
381 |
+
if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
|
382 |
+
#if not "initial_lr" in adafactor_scheduler_arg:
|
383 |
+
# adafactor_scheduler_arg = args.learning_rate
|
384 |
+
args.learning_rate = None
|
385 |
+
args.text_encoder_lr = None
|
386 |
+
args.unet_lr = None
|
387 |
+
print(f"optimizer arg: {optimizer_arg}")
|
388 |
+
print("=-----------------------------------=")
|
389 |
+
if not AdafactorScheduler_Flag: args.split_lora_networks = False
|
390 |
if args.split_lora_networks:
|
|
|
391 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
392 |
+
append_module.replace_prepare_optimizer_params(network)
|
393 |
+
trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
|
394 |
else:
|
395 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
396 |
+
_list_of_init_lr = []
|
397 |
+
print(f"trainable_params_len: {len(trainable_params)}")
|
398 |
+
if len(_list_of_init_lr)>0:
|
399 |
+
list_of_init_lr = _list_of_init_lr
|
400 |
+
print(f"split loras network is {len(list_of_init_lr)}")
|
401 |
+
if len(list_of_init_lr) > 0:
|
402 |
+
adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
|
403 |
+
|
404 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
|
405 |
+
|
406 |
+
if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
|
407 |
+
optimizer = optim.Lookahead(optimizer, **lookahed_arg)
|
408 |
+
print(f"lookahed_arg: {lookahed_arg}")
|
409 |
+
|
|
|
|
|
|
|
|
|
|
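The --optimizer_arg values handled above are plain key=value strings. A reduced sketch of the parsing convention (booleans, comma pairs such as betas, everything else treated as float); it mirrors the loop above but is not a verbatim copy and omits the int/str special cases:

    def parse_optimizer_args(pairs):
        kwargs = {}
        for pair in pairs:                       # e.g. ["eps=1e-16", "betas=0.9,0.999", "rectify=False"]
            key, value = pair.split("=")
            if value in ("True", "False"):
                kwargs[key] = (value == "True")
            elif "," in value:
                a, b = value.split(",")
                kwargs[key] = (float(a), float(b))
            else:
                kwargs[key] = float(value)
        return kwargs

    print(parse_optimizer_args(["eps=1e-16", "betas=0.9,0.999", "rectify=False"]))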
|
|
|
|
|
|
|
|
|
|
|
410 |
# dataloaderを準備する
|
411 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
412 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
413 |
train_dataloader = torch.utils.data.DataLoader(
|
414 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
415 |
|
416 |
# 学習ステップ数を計算する
|
417 |
if args.max_train_epochs is not None:
|
418 |
+
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
419 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
|
|
420 |
|
421 |
# lr schedulerを用意する
|
422 |
+
# lr_scheduler = diffusers.optimization.get_scheduler(
|
423 |
+
if AdafactorScheduler_Flag:
|
424 |
+
print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
|
425 |
+
lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
|
426 |
+
print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
|
427 |
+
del list_of_init_lr
|
428 |
else:
|
429 |
+
lr_scheduler = get_scheduler_fix(
|
430 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
431 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
432 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
433 |
+
|
434 |
#追加機能の設定をコメントに追記して残す
|
435 |
+
args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
|
436 |
+
if AdafactorScheduler_Flag:
|
437 |
+
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
|
|
|
438 |
if args.min_resolution:
|
439 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
440 |
|
|
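When relative_step stays on, Adafactor derives its own step size and the external scheduler is replaced by an AdafactorSchedule, as in the branch above. A minimal sketch using the stock transformers classes; the append_module variant additionally seeds per-group initial_lr values, which this sketch does not show:

    from transformers.optimization import Adafactor, AdafactorSchedule
    import torch

    params = torch.nn.Linear(4, 4).parameters()   # placeholder parameters
    optimizer = Adafactor(params, lr=None, relative_step=True, scale_parameter=True, warmup_init=True)
    lr_scheduler = AdafactorSchedule(optimizer, initial_lr=0.0)   # reports the lr Adafactor computes internally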
|
504 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
505 |
|
506 |
# 学習する
|
|
|
507 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
508 |
+
print("running training / 学習開始")
|
509 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
510 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
511 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
512 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
513 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
514 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
515 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
516 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
517 |
+
|
|
|
|
|
|
|
518 |
metadata = {
|
519 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
520 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
522 |
"ss_learning_rate": args.learning_rate,
|
523 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
524 |
"ss_unet_lr": args.unet_lr,
|
525 |
+
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
526 |
+
"ss_num_reg_images": train_dataset.num_reg_images,
|
527 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
528 |
"ss_num_epochs": num_train_epochs,
|
529 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
530 |
+
"ss_total_batch_size": total_batch_size,
|
531 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
532 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
533 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
539 |
"ss_mixed_precision": args.mixed_precision,
|
540 |
"ss_full_fp16": bool(args.full_fp16),
|
541 |
"ss_v2": bool(args.v2),
|
542 |
+
"ss_resolution": args.resolution,
|
543 |
"ss_clip_skip": args.clip_skip,
|
544 |
"ss_max_token_length": args.max_token_length,
|
545 |
+
"ss_color_aug": bool(args.color_aug),
|
546 |
+
"ss_flip_aug": bool(args.flip_aug),
|
547 |
+
"ss_random_crop": bool(args.random_crop),
|
548 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
549 |
"ss_cache_latents": bool(args.cache_latents),
|
550 |
+
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
551 |
+
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
552 |
+
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
553 |
"ss_seed": args.seed,
|
554 |
+
"ss_keep_tokens": args.keep_tokens,
|
555 |
"ss_noise_offset": args.noise_offset,
|
556 |
+
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
557 |
+
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
558 |
+
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
559 |
+
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
560 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
561 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
}
|
563 |
|
564 |
+
# uncomment if another network is added
|
|
|
565 |
# for key, value in net_kwargs.items():
|
566 |
# metadata["ss_arg_" + key] = value
|
567 |
|
|
|
568 |
if args.pretrained_model_name_or_path is not None:
|
569 |
sd_model_name = args.pretrained_model_name_or_path
|
570 |
if os.path.exists(sd_model_name):
|
|
|
583 |
|
584 |
metadata = {k: str(v) for k, v in metadata.items()}
|
585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
587 |
global_step = 0
|
588 |
|
|
|
595 |
loss_list = []
|
596 |
loss_total = 0.0
|
597 |
for epoch in range(num_train_epochs):
|
598 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
599 |
+
train_dataset.set_current_epoch(epoch + 1)
|
|
|
600 |
|
601 |
metadata["ss_epoch"] = str(epoch+1)
|
602 |
|
|
|
633 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
634 |
|
635 |
# Predict the noise residual
|
636 |
+
with autocast():
|
637 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
638 |
|
639 |
if args.v_parameterization:
|
|
|
651 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
652 |
|
653 |
accelerator.backward(loss)
|
654 |
+
if accelerator.sync_gradients:
|
655 |
params_to_clip = network.get_trainable_params()
|
656 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
657 |
|
658 |
optimizer.step()
|
659 |
+
lr_scheduler.step()
|
|
|
660 |
optimizer.zero_grad(set_to_none=True)
|
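Gradient clipping is gated on accelerator.sync_gradients, so with gradient accumulation it only fires on the iterations where an optimizer step actually happens. A minimal, self-contained sketch of the same update pattern (toy model and values, not this script's objects):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

for batch in [torch.randn(2, 8) for _ in range(8)]:
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        if accelerator.sync_gradients:  # True only on iterations that perform the real update
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)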
661 |
|
662 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
|
664 |
progress_bar.update(1)
|
665 |
global_step += 1
|
666 |
|
|
|
|
|
667 |
current_loss = loss.detach().item()
|
668 |
if epoch == 0:
|
669 |
loss_list.append(current_loss)
|
|
|
676 |
progress_bar.set_postfix(**logs)
|
677 |
|
678 |
if args.logging_dir is not None:
|
679 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
680 |
accelerator.log(logs, step=global_step)
|
681 |
|
682 |
if global_step >= args.max_train_steps:
|
|
|
694 |
def save_func():
|
695 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
696 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
697 |
print(f"saving checkpoint: {ckpt_file}")
|
698 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
699 |
|
700 |
def remove_old_func(old_epoch_no):
|
701 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
704 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
705 |
os.remove(old_ckpt_file)
|
706 |
|
707 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
708 |
+
if saving and args.save_state:
|
709 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
|
710 |
|
711 |
# end of epoch
|
712 |
|
713 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
714 |
|
715 |
+
is_main_process = accelerator.is_main_process
|
716 |
if is_main_process:
|
717 |
network = unwrap_model(network)
|
718 |
|
|
|
731 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
732 |
|
733 |
print(f"save trained model to {ckpt_file}")
|
734 |
+
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
735 |
print("model saved.")
|
736 |
|
737 |
|
|
|
741 |
train_util.add_sd_models_arguments(parser)
|
742 |
train_util.add_dataset_arguments(parser, True, True, True)
|
743 |
train_util.add_training_arguments(parser, True)
|
744 |
|
745 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
746 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
748 |
|
749 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
750 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
751 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
752 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
753 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
754 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
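These two options only apply to the cosine_with_restarts and polynomial schedulers. A rough sketch of what they control, shown with the equivalent schedule helpers from transformers (how this script actually wires them into its own scheduler setup is not shown here):

import torch
from transformers import (
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

# --lr_scheduler_num_cycles: number of cosine restarts over the whole run
sched_restarts = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10_000, num_cycles=3)

# --lr_scheduler_power: exponent of the polynomial decay (1.0 behaves like linear)
sched_poly = get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10_000, power=2.0)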
755 |
|
756 |
parser.add_argument("--network_weights", type=str, default=None,
|
757 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
771 |
#Optimizer変更関連のオプション追加
|
772 |
append_module.add_append_arguments(parser)
|
773 |
args = append_module.get_config(parser)
|
774 |
|
775 |
if args.resolution==args.min_resolution:
|
776 |
args.min_resolution=None
|
777 |
|
778 |
train(args)
|
|
|
779 |
|
780 |
+
#学習が終わったら現在のargsを保存する
|
781 |
+
# import yaml
|
782 |
+
# import datetime
|
783 |
+
# _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
784 |
+
# if args.output_name==None:
|
785 |
+
# config_name = f"train_network_config_{_t}.yaml"
|
786 |
+
# else:
|
787 |
+
# config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
788 |
+
# print(f"{config_name} に設定を書き出し中...")
|
789 |
+
# with open(config_name, mode="w") as f:
|
790 |
+
# yaml.dump(args.__dict__, f, indent=4)
|
791 |
+
# print("done!")
|
792 |
|
793 |
'''
|
794 |
optimizer設定メモ
|
795 |
(optimizer_argから設定できるように変更するためのメモ)
|
796 |
|
797 |
AdamWのweight_decay初期値は1e-2
|
|
|
821 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
822 |
huggingfaceのサンプルパラ
|
823 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
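A minimal sketch of building Adafactor with the Hugging Face sample parameters quoted above, using the transformers implementation (whether the script passes them exactly like this through optimizer_arg is an assumption):

import torch
from transformers.optimization import Adafactor

params = [torch.nn.Parameter(torch.zeros(10))]
optimizer = Adafactor(
    params,
    lr=1e-3,                  # an external lr is required once relative_step=False
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)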
|
|
824 |
|
825 |
AggMo
|
826 |
|
train_textual_inversion.py
CHANGED
@@ -11,11 +11,7 @@ import diffusers
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
-
|
15 |
-
from library.config_util import (
|
16 |
-
ConfigSanitizer,
|
17 |
-
BlueprintGenerator,
|
18 |
-
)
|
19 |
|
20 |
imagenet_templates_small = [
|
21 |
"a photo of a {}",
|
@@ -83,6 +79,7 @@ def train(args):
|
|
83 |
train_util.prepare_dataset_args(args, True)
|
84 |
|
85 |
cache_latents = args.cache_latents
|
|
|
86 |
|
87 |
if args.seed is not None:
|
88 |
set_seed(args.seed)
|
@@ -142,35 +139,21 @@ def train(args):
|
|
142 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
143 |
|
144 |
# データセットを準備する
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
else:
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
}
|
161 |
-
else:
|
162 |
-
print("Train with captions.")
|
163 |
-
user_config = {
|
164 |
-
"datasets": [{
|
165 |
-
"subsets": [{
|
166 |
-
"image_dir": args.train_data_dir,
|
167 |
-
"metadata_file": args.in_json,
|
168 |
-
}]
|
169 |
-
}]
|
170 |
-
}
|
171 |
-
|
172 |
-
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
173 |
-
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
174 |
|
175 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
176 |
if use_template:
|
@@ -180,25 +163,20 @@ def train(args):
|
|
180 |
captions = []
|
181 |
for tmpl in templates:
|
182 |
captions.append(tmpl.format(replace_to))
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
else:
|
190 |
-
prompt_replacement = None
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
-
train_util.debug_dataset(
|
194 |
return
|
195 |
-
if len(
|
196 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
197 |
return
|
198 |
|
199 |
-
if cache_latents:
|
200 |
-
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
201 |
-
|
202 |
# モデルに xformers とか memory efficient attention を組み込む
|
203 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
204 |
|
@@ -208,7 +186,7 @@ def train(args):
|
|
208 |
vae.requires_grad_(False)
|
209 |
vae.eval()
|
210 |
with torch.no_grad():
|
211 |
-
|
212 |
vae.to("cpu")
|
213 |
if torch.cuda.is_available():
|
214 |
torch.cuda.empty_cache()
|
@@ -220,14 +198,35 @@ def train(args):
|
|
220 |
|
221 |
# 学習に必要なクラスを準備する
|
222 |
print("prepare optimizer, data loader etc.")
|
223 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
224 |
-
|
225 |
|
226 |
# dataloaderを準備する
|
227 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
228 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
229 |
train_dataloader = torch.utils.data.DataLoader(
|
230 |
-
|
231 |
|
232 |
# 学習ステップ数を計算する
|
233 |
if args.max_train_epochs is not None:
|
@@ -235,9 +234,8 @@ def train(args):
|
|
235 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
236 |
|
237 |
# lr schedulerを用意する
|
238 |
-
lr_scheduler =
|
239 |
-
|
240 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
241 |
|
242 |
# acceleratorがなんかよろしくやってくれるらしい
|
243 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
@@ -285,8 +283,8 @@ def train(args):
|
|
285 |
# 学習する
|
286 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
287 |
print("running training / 学習開始")
|
288 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
289 |
-
print(f" num reg images / 正則化画像の数: {
|
290 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
291 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
292 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -305,11 +303,12 @@ def train(args):
|
|
305 |
|
306 |
for epoch in range(num_train_epochs):
|
307 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
308 |
-
|
309 |
|
310 |
text_encoder.train()
|
311 |
|
312 |
loss_total = 0
|
|
|
313 |
for step, batch in enumerate(train_dataloader):
|
314 |
with accelerator.accumulate(text_encoder):
|
315 |
with torch.no_grad():
|
@@ -358,9 +357,9 @@ def train(args):
|
|
358 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
359 |
|
360 |
accelerator.backward(loss)
|
361 |
-
if accelerator.sync_gradients
|
362 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
363 |
-
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
364 |
|
365 |
optimizer.step()
|
366 |
lr_scheduler.step()
|
@@ -375,14 +374,9 @@ def train(args):
|
|
375 |
progress_bar.update(1)
|
376 |
global_step += 1
|
377 |
|
378 |
-
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
379 |
-
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
380 |
-
|
381 |
current_loss = loss.detach().item()
|
382 |
if args.logging_dir is not None:
|
383 |
-
logs = {"loss": current_loss, "lr":
|
384 |
-
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
385 |
-
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
386 |
accelerator.log(logs, step=global_step)
|
387 |
|
388 |
loss_total += current_loss
|
@@ -400,6 +394,8 @@ def train(args):
|
|
400 |
accelerator.wait_for_everyone()
|
401 |
|
402 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
|
|
|
403 |
|
404 |
if args.save_every_n_epochs is not None:
|
405 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
@@ -421,9 +417,6 @@ def train(args):
|
|
421 |
if saving and args.save_state:
|
422 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
423 |
|
424 |
-
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
425 |
-
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
426 |
-
|
427 |
# end of epoch
|
428 |
|
429 |
is_main_process = accelerator.is_main_process
|
@@ -498,8 +491,6 @@ if __name__ == '__main__':
|
|
498 |
train_util.add_sd_models_arguments(parser)
|
499 |
train_util.add_dataset_arguments(parser, True, True, False)
|
500 |
train_util.add_training_arguments(parser, True)
|
501 |
-
train_util.add_optimizer_arguments(parser)
|
502 |
-
config_util.add_config_arguments(parser)
|
503 |
|
504 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
505 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
+
from library.train_util import DreamBoothDataset, FineTuningDataset
|
15 |
|
16 |
imagenet_templates_small = [
|
17 |
"a photo of a {}",
|
|
|
79 |
train_util.prepare_dataset_args(args, True)
|
80 |
|
81 |
cache_latents = args.cache_latents
|
82 |
+
use_dreambooth_method = args.in_json is None
|
83 |
|
84 |
if args.seed is not None:
|
85 |
set_seed(args.seed)
|
|
|
139 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
140 |
|
141 |
# データセットを準備する
|
142 |
+
if use_dreambooth_method:
|
143 |
+
print("Use DreamBooth method.")
|
144 |
+
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
145 |
+
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
146 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
147 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
148 |
+
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
149 |
else:
|
150 |
+
print("Train with captions.")
|
151 |
+
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
152 |
+
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
153 |
+
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
154 |
+
args.bucket_reso_steps, args.bucket_no_upscale,
|
155 |
+
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
156 |
+
args.dataset_repeats, args.debug_dataset)
|
157 |
|
158 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
159 |
if use_template:
|
|
|
163 |
captions = []
|
164 |
for tmpl in templates:
|
165 |
captions.append(tmpl.format(replace_to))
|
166 |
+
train_dataset.add_replacement("", captions)
|
167 |
+
elif args.num_vectors_per_token > 1:
|
168 |
+
replace_to = " ".join(token_strings)
|
169 |
+
train_dataset.add_replacement(args.token_string, replace_to)
|
170 |
+
|
171 |
+
train_dataset.make_buckets()
|
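With num_vectors_per_token > 1 the trigger word is expanded into one sub-token per embedding vector before captions are built, which is what the add_replacement call above sets up. A tiny sketch of the resulting string rewrite (sample values only):

token_string = "mychar"
num_vectors_per_token = 3

# "mychar" -> "mychar mychar1 mychar2", one learned vector per sub-token
token_strings = [token_string] + [f"{token_string}{i}" for i in range(1, num_vectors_per_token)]
replace_to = " ".join(token_strings)

caption = "a photo of mychar on the beach"
print(caption.replace(token_string, replace_to))
# -> a photo of mychar mychar1 mychar2 on the beach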
|
|
|
|
172 |
|
173 |
if args.debug_dataset:
|
174 |
+
train_util.debug_dataset(train_dataset, show_input_ids=True)
|
175 |
return
|
176 |
+
if len(train_dataset) == 0:
|
177 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
178 |
return
|
179 |
|
180 |
# モデルに xformers とか memory efficient attention を組み込む
|
181 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
182 |
|
|
|
186 |
vae.requires_grad_(False)
|
187 |
vae.eval()
|
188 |
with torch.no_grad():
|
189 |
+
train_dataset.cache_latents(vae)
|
190 |
vae.to("cpu")
|
191 |
if torch.cuda.is_available():
|
192 |
torch.cuda.empty_cache()
|
|
|
198 |
|
199 |
# 学習に必要なクラスを準備する
|
200 |
print("prepare optimizer, data loader etc.")
|
201 |
+
|
202 |
+
# 8-bit Adamを使う
|
203 |
+
if args.use_8bit_adam:
|
204 |
+
try:
|
205 |
+
import bitsandbytes as bnb
|
206 |
+
except ImportError:
|
207 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
208 |
+
print("use 8-bit Adam optimizer")
|
209 |
+
optimizer_class = bnb.optim.AdamW8bit
|
210 |
+
elif args.use_lion_optimizer:
|
211 |
+
try:
|
212 |
+
import lion_pytorch
|
213 |
+
except ImportError:
|
214 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
215 |
+
print("use Lion optimizer")
|
216 |
+
optimizer_class = lion_pytorch.Lion
|
217 |
+
else:
|
218 |
+
optimizer_class = torch.optim.AdamW
|
219 |
+
|
220 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
221 |
+
|
222 |
+
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
223 |
+
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
224 |
|
225 |
# dataloaderを準備する
|
226 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
227 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
228 |
train_dataloader = torch.utils.data.DataLoader(
|
229 |
+
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
230 |
|
231 |
# 学習ステップ数を計算する
|
232 |
if args.max_train_epochs is not None:
|
|
|
234 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
235 |
|
236 |
# lr schedulerを用意する
|
237 |
+
lr_scheduler = diffusers.optimization.get_scheduler(
|
238 |
+
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
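diffusers' get_scheduler resolves the scheduler by its string name, and num_training_steps is scaled by gradient_accumulation_steps so the schedule spans every lr_scheduler.step() call made inside the loop. A standalone usage sketch with placeholder numbers:

import torch
from diffusers.optimization import get_scheduler

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)
lr_scheduler = get_scheduler(
    "cosine",                      # scheduler name, as passed via --lr_scheduler
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1600 * 4,   # max_train_steps * gradient_accumulation_steps
)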
|
|
239 |
|
240 |
# acceleratorがなんかよろしくやってくれるらしい
|
241 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
283 |
# 学習する
|
284 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
285 |
print("running training / 学習開始")
|
286 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
287 |
+
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
288 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
289 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
290 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
303 |
|
304 |
for epoch in range(num_train_epochs):
|
305 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
306 |
+
train_dataset.set_current_epoch(epoch + 1)
|
307 |
|
308 |
text_encoder.train()
|
309 |
|
310 |
loss_total = 0
|
311 |
+
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
312 |
for step, batch in enumerate(train_dataloader):
|
313 |
with accelerator.accumulate(text_encoder):
|
314 |
with torch.no_grad():
|
|
|
357 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
358 |
|
359 |
accelerator.backward(loss)
|
360 |
+
if accelerator.sync_gradients:
|
361 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
362 |
+
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
363 |
|
364 |
optimizer.step()
|
365 |
lr_scheduler.step()
|
|
|
374 |
progress_bar.update(1)
|
375 |
global_step += 1
|
376 |
|
|
|
|
|
|
|
377 |
current_loss = loss.detach().item()
|
378 |
if args.logging_dir is not None:
|
379 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
380 |
accelerator.log(logs, step=global_step)
|
381 |
|
382 |
loss_total += current_loss
|
|
|
394 |
accelerator.wait_for_everyone()
|
395 |
|
396 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
397 |
+
# d = updated_embs - bef_epo_embs
|
398 |
+
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
399 |
|
400 |
if args.save_every_n_epochs is not None:
|
401 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
|
|
417 |
if saving and args.save_state:
|
418 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
419 |
|
420 |
# end of epoch
|
421 |
|
422 |
is_main_process = accelerator.is_main_process
|
|
|
491 |
train_util.add_sd_models_arguments(parser)
|
492 |
train_util.add_dataset_arguments(parser, True, True, False)
|
493 |
train_util.add_training_arguments(parser, True)
|
494 |
|
495 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
496 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|