abc committed on
Commit 74be2a5 · 1 Parent(s): fbecb28

Upload 35 files

append_module.py CHANGED
@@ -2,19 +2,7 @@ import argparse
2
  import json
3
  import shutil
4
  import time
5
- from typing import (
6
- Dict,
7
- List,
8
- NamedTuple,
9
- Optional,
10
- Sequence,
11
- Tuple,
12
- Union,
13
- )
14
- from dataclasses import (
15
- asdict,
16
- dataclass,
17
- )
18
  from accelerate import Accelerator
19
  from torch.autograd.function import Function
20
  import glob
@@ -40,7 +28,6 @@ import safetensors.torch
40
 
41
  import library.model_util as model_util
42
  import library.train_util as train_util
43
- import library.config_util as config_util
44
 
45
  #============================================================================================================
46
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
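  The comment above marks a provisional tweak so that AdafactorSchedule can apply initial_lr per layer; the patched AdafactorSchedule_append itself sits outside the visible hunks. For reference, a minimal sketch of the stock transformers API it extends (the per-layer variant presumably accepts a list of initial_lr values, one per parameter group):

  import torch
  import transformers

  model = torch.nn.Linear(4, 4)  # stand-in module
  # With relative_step=True Adafactor computes its own lr; AdafactorSchedule mainly exists so that
  # logging and schedulers have a value to report, seeded here with a single shared initial_lr.
  optimizer = transformers.optimization.Adafactor(model.parameters(), relative_step=True, warmup_init=True, lr=None)
  lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=1e-3)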
@@ -128,124 +115,6 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
128
  return area_size_resos_list, area_size_list
129
 
130
  #============================================================================================================
131
- #config_util 内より
132
- #============================================================================================================
133
- @dataclass
134
- class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
135
- min_resolution: Optional[Tuple[int, int]] = None
136
- area_step : int = 2
137
-
138
- class ConfigSanitizer(config_util.ConfigSanitizer):
139
- #@config_util.curry
140
- @staticmethod
141
- def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
142
- config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
143
- return tuple(value)
144
-
145
- #@config_util.curry
146
- @staticmethod
147
- def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
148
- config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
149
- try:
150
- config_util.Schema(klass)(value)
151
- return (value, value)
152
- except:
153
- return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
154
- # datasets schema
155
- DATASET_ASCENDABLE_SCHEMA = {
156
- "batch_size": int,
157
- "bucket_no_upscale": bool,
158
- "bucket_reso_steps": int,
159
- "enable_bucket": bool,
160
- "max_bucket_reso": int,
161
- "min_bucket_reso": int,
162
- "resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
163
- "min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
164
- "area_step": int,
165
- }
166
- def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
167
- super().__init__(support_dreambooth, support_finetuning, support_dropout)
168
- def _check(self):
169
- print(self.db_dataset_schema)
170
-
171
- class BlueprintGenerator(config_util.BlueprintGenerator):
172
- def __init__(self, sanitizer: ConfigSanitizer):
173
- config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
174
- super().__init__(sanitizer)
175
-
176
- def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
177
- datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
178
-
179
- for dataset_blueprint in dataset_group_blueprint.datasets:
180
- if dataset_blueprint.is_dreambooth:
181
- subset_klass = train_util.DreamBoothSubset
182
- dataset_klass = DreamBoothDataset
183
- else:
184
- subset_klass = train_util.FineTuningSubset
185
- dataset_klass = train_util.FineTuningDataset
186
-
187
- subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
188
- dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
189
- datasets.append(dataset)
190
-
191
- # print info
192
- info = ""
193
- for i, dataset in enumerate(datasets):
194
- is_dreambooth = isinstance(dataset, DreamBoothDataset)
195
- info += config_util.dedent(f"""\
196
- [Dataset {i}]
197
- batch_size: {dataset.batch_size}
198
- resolution: {(dataset.width, dataset.height)}
199
- enable_bucket: {dataset.enable_bucket}
200
- """)
201
-
202
- if dataset.enable_bucket:
203
- info += config_util.indent(config_util.dedent(f"""\
204
- min_bucket_reso: {dataset.min_bucket_reso}
205
- max_bucket_reso: {dataset.max_bucket_reso}
206
- bucket_reso_steps: {dataset.bucket_reso_steps}
207
- bucket_no_upscale: {dataset.bucket_no_upscale}
208
- \n"""), " ")
209
- else:
210
- info += "\n"
211
-
212
- for j, subset in enumerate(dataset.subsets):
213
- info += config_util.indent(config_util.dedent(f"""\
214
- [Subset {j} of Dataset {i}]
215
- image_dir: "{subset.image_dir}"
216
- image_count: {subset.img_count}
217
- num_repeats: {subset.num_repeats}
218
- shuffle_caption: {subset.shuffle_caption}
219
- keep_tokens: {subset.keep_tokens}
220
- caption_dropout_rate: {subset.caption_dropout_rate}
221
- caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
222
- caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
223
- color_aug: {subset.color_aug}
224
- flip_aug: {subset.flip_aug}
225
- face_crop_aug_range: {subset.face_crop_aug_range}
226
- random_crop: {subset.random_crop}
227
- """), " ")
228
-
229
- if is_dreambooth:
230
- info += config_util.indent(config_util.dedent(f"""\
231
- is_reg: {subset.is_reg}
232
- class_tokens: {subset.class_tokens}
233
- caption_extension: {subset.caption_extension}
234
- \n"""), " ")
235
- else:
236
- info += config_util.indent(config_util.dedent(f"""\
237
- metadata_file: {subset.metadata_file}
238
- \n"""), " ")
239
-
240
- print(info)
241
-
242
- # make buckets first because it determines the length of dataset
243
- for i, dataset in enumerate(datasets):
244
- print(f"[Dataset {i}]")
245
- dataset.make_buckets()
246
-
247
- return train_util.DatasetGroup(datasets)
248
- #============================================================================================================
249
  #train_util 内より
250
  #============================================================================================================
251
  class BucketManager_append(train_util.BucketManager):
@@ -310,7 +179,7 @@ class BucketManager_append(train_util.BucketManager):
310
  bucket_size_id_list.append(bucket_size_id + i + 1)
311
  _min_error = 1000.
312
  _min_id = bucket_size_id
313
- for now_size_id in bucket_size_id_list:
314
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
315
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
316
  ar_error = np.abs(ar_errors).min()
@@ -384,13 +253,13 @@ class BucketManager_append(train_util.BucketManager):
384
  return reso, resized_size, ar_error
385
 
386
  class DreamBoothDataset(train_util.DreamBoothDataset):
387
- def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
388
  print("use append DreamBoothDataset")
389
  self.min_resolution = min_resolution
390
  self.area_step = area_step
391
- super().__init__(subsets, batch_size, tokenizer, max_token_length,
392
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
393
- prior_loss_weight, debug_dataset)
394
  def make_buckets(self):
395
  '''
396
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
@@ -483,50 +352,40 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
483
  self.shuffle_buckets()
484
  self._length = len(self.buckets_indices)
485
 
486
- import transformers
487
- from torch.optim import Optimizer
488
- from diffusers.optimization import SchedulerType
489
- from typing import Union
490
- def get_scheduler_Adafactor(
491
- name: Union[str, SchedulerType],
492
- optimizer: Optimizer,
493
- scheduler_arg: Dict
494
- ):
495
- if name.startswith("adafactor"):
496
- assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
497
- print(scheduler_arg)
498
- return AdafactorSchedule_append(optimizer, **scheduler_arg)
 
499
  #============================================================================================================
500
  #networks.lora
501
  #============================================================================================================
502
- #from networks.lora import LoRANetwork
503
- def replace_prepare_optimizer_params(networks, network_module):
504
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
505
-
506
  def enumerate_params(loras, lora_name=None):
507
  params = []
508
  for lora in loras:
509
  if lora_name is not None:
510
- get_param_flag = False
511
- if "attentions" in lora_name or "lora_unet_up_blocks_0_resnets_2":
512
- lora_names = [lora_name]
513
- if "attentions" in lora_name:
514
- lora_names.append(lora_name.replace("attentions", "resnets"))
515
- elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
516
- lora_names.append("lora_unet_up_blocks_0_upsamplers_")
517
- elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
518
- lora_names.append("lora_unet_up_blocks_1_upsamplers_")
519
- elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
520
- lora_names.append("lora_unet_up_blocks_2_upsamplers_")
521
-
522
- for _name in lora_names:
523
- if _name in lora.lora_name:
524
- get_param_flag = True
525
- break
526
- else:
527
- if lora_name in lora.lora_name:
528
- get_param_flag = True
529
- if get_param_flag: params.extend(lora.parameters())
530
  else:
531
  params.extend(lora.parameters())
532
  return params
@@ -534,7 +393,6 @@ def replace_prepare_optimizer_params(networks, network_module):
534
  self.requires_grad_(True)
535
  all_params = []
536
  ret_scheduler_lr = []
537
- used_names = []
538
 
539
  if loranames is not None:
540
  textencoder_names = [None]
@@ -547,181 +405,37 @@ def replace_prepare_optimizer_params(networks, network_module):
547
  if self.text_encoder_loras:
548
  for textencoder_name in textencoder_names:
549
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
550
- used_names.append(textencoder_name)
551
  if text_encoder_lr is not None:
552
  param_data['lr'] = text_encoder_lr
553
- if lr_dic is not None:
554
- if textencoder_name in lr_dic:
555
- param_data['lr'] = lr_dic[textencoder_name]
556
- print(f"{textencoder_name} lr: {param_data['lr']}")
557
-
558
- if block_args_dic is not None:
559
- if "lora_te_" in block_args_dic:
560
- for pname, value in block_args_dic["lora_te_"].items():
561
- param_data[pname] = value
562
- if textencoder_name in block_args_dic:
563
- for pname, value in block_args_dic[textencoder_name].items():
564
- param_data[pname] = value
565
-
566
- if text_encoder_lr is not None:
567
- ret_scheduler_lr.append(text_encoder_lr)
568
- else:
569
- ret_scheduler_lr.append(0.)
570
- if lr_dic is not None:
571
- if textencoder_name in lr_dic:
572
- ret_scheduler_lr[-1] = lr_dic[textencoder_name]
573
  all_params.append(param_data)
574
 
575
  if self.unet_loras:
576
  for unet_name in unet_names:
577
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
578
- if len(param_data["params"])==0: continue
579
- used_names.append(unet_name)
580
  if unet_lr is not None:
581
  param_data['lr'] = unet_lr
582
- if lr_dic is not None:
583
- if unet_name in lr_dic:
584
- param_data['lr'] = lr_dic[unet_name]
585
- print(f"{unet_name} lr: {param_data['lr']}")
586
-
587
- if block_args_dic is not None:
588
- if "lora_unet_" in block_args_dic:
589
- for pname, value in block_args_dic["lora_unet_"].items():
590
- param_data[pname] = value
591
- if unet_name in block_args_dic:
592
- for pname, value in block_args_dic[unet_name].items():
593
- param_data[pname] = value
594
-
595
- if unet_lr is not None:
596
- ret_scheduler_lr.append(unet_lr)
597
- else:
598
- ret_scheduler_lr.append(0.)
599
- if lr_dic is not None:
600
- if unet_name in lr_dic:
601
- ret_scheduler_lr[-1] = lr_dic[unet_name]
602
  all_params.append(param_data)
603
 
604
- return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
605
- try:
606
- network_module.LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
607
- except:
608
- print("cant't replace prepare_optimizer_params")
609
 
610
  #============================================================================================================
611
  #新規追加
612
  #============================================================================================================
613
  def add_append_arguments(parser: argparse.ArgumentParser):
614
  # for train_network_opt.py
615
- #parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
616
- #parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
617
- parser.add_argument("--use_lookahead", action="store_true")
618
- parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
619
  parser.add_argument("--split_lora_networks", action="store_true")
620
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
621
- parser.add_argument("--blocks_lr_setting", type=str, default=None)
622
- parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
623
  parser.add_argument("--min_resolution", type=str, default=None)
624
  parser.add_argument("--area_step", type=int, default=1)
625
  parser.add_argument("--config", type=str, default=None)
626
- parser.add_argument("--not_output_config", action="store_true")
627
-
628
- class MyNetwork_Names:
629
- ex_block_weight_dic = {
630
- "BASE": ["te"],
631
- "IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
632
- "IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
633
- "IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
634
- "IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
635
- "MID": ["mid"],
636
- "OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
637
- "OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
638
- "OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
639
- "OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
640
- }
641
-
642
- blocks_name_dic = { "te": "lora_te_",
643
- "unet": "lora_unet_",
644
- "mid": "lora_unet_mid_block_",
645
- "down": "lora_unet_down_blocks_",
646
- "up": "lora_unet_up_blocks_"}
647
- for i in range(12):
648
- blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
649
- for i in range(3):
650
- blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
651
- blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
652
- for i in range(4):
653
- for j in range(2):
654
- if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
655
- blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
656
- for j in range(3):
657
- if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
658
- blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
659
- if i<=2:
660
- blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
661
- blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
662
-
663
- def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
664
- ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
665
- blocks_name_dic = MyNetwork_Names.blocks_name_dic
666
-
667
- lr_dic = {}
668
- if lr_setting_str==None or lr_setting_str=="":
669
- pass
670
- else:
671
- lr_settings = lr_setting_str.replace(" ", "").split(",")
672
- for lr_setting in lr_settings:
673
- key, value = lr_setting.split("=")
674
- if key in ex_block_weight_dic:
675
- keys = ex_block_weight_dic[key]
676
- else:
677
- keys = [key]
678
- for key in keys:
679
- if key in blocks_name_dic:
680
- new_key = blocks_name_dic[key]
681
- lr_dic[new_key] = float(value)
682
- if len(lr_dic)==0:
683
- lr_dic = None
684
-
685
- args_dic = {}
686
- if (block_optim_args is None):
687
- block_optim_args = []
688
- if (len(block_optim_args)>0):
689
- for my_arg in block_optim_args:
690
- my_arg = my_arg.replace(" ", "")
691
- splits = my_arg.split(":")
692
- b_name = splits[0]
693
-
694
- key, _value = splits[1].split("=")
695
- value_type = float
696
- if len(splits)==3:
697
- if _value=="str":
698
- value_type = str
699
- elif _value=="int":
700
- value_type = int
701
- _value = splits[2]
702
- if _value=="true" or _value=="false":
703
- value_type = bool
704
- if "," in _value:
705
- _value = _value.split(",")
706
- for i in range(len(_value)):
707
- _value[i] = value_type(_value[i])
708
- value=tuple(_value)
709
- else:
710
- value = value_type(_value)
711
-
712
- if b_name in ex_block_weight_dic:
713
- b_names = ex_block_weight_dic[b_name]
714
- else:
715
- b_names = [b_name]
716
- for b_name in b_names:
717
- new_b_name = blocks_name_dic[b_name]
718
- if not new_b_name in args_dic:
719
- args_dic[new_b_name] = {}
720
- args_dic[new_b_name][key] = value
721
-
722
- if len(args_dic)==0:
723
- args_dic = None
724
- return lr_dic, args_dic
725
 
726
  def create_split_names(split_flag, split_level):
727
  split_names = None
@@ -732,28 +446,14 @@ def create_split_names(split_flag, split_level):
732
  if split_level==1:
733
  unet_names.append(f"lora_unet_down_blocks_")
734
  unet_names.append(f"lora_unet_up_blocks_")
735
- elif split_level==2 or split_level==0 or split_level==4:
736
- if split_level>=2:
737
  text_encoder_names = []
738
  for i in range(12):
739
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
740
-
741
- if split_level<=2:
742
- for i in range(3):
743
- unet_names.append(f"lora_unet_down_blocks_{i}")
744
- unet_names.append(f"lora_unet_up_blocks_{i+1}")
745
-
746
- if split_level>=3:
747
- for i in range(4):
748
- for j in range(2):
749
- if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
750
- if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
751
- for j in range(3):
752
- if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
753
- if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
754
- if i<=2:
755
- unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
756
-
757
  split_names["text_encoder"] = text_encoder_names
758
  split_names["unet"] = unet_names
759
  return split_names
@@ -765,7 +465,7 @@ def get_config(parser):
765
  import datetime
766
  if os.path.splitext(args.config)[-1] == ".yaml":
767
  args.config = os.path.splitext(args.config)[0]
768
- config_path = f"{args.config}.yaml"
769
  if os.path.exists(config_path):
770
  print(f"{config_path} から設定を読み込み中...")
771
  margs, rest = parser.parse_known_args()
@@ -786,41 +486,19 @@ def get_config(parser):
786
  args_type_dic[key] = act.type
787
  #データタイプの確認とargsにkeyの内容を代入していく
788
  for key, v in configs.items():
789
- if v is not None:
790
- if key in args_dic:
791
- if args_dic[key] is not None:
792
- new_type = type(args_dic[key])
793
- if (not type(v) == new_type) and (not new_type==list):
794
- v = new_type(v)
795
- else:
796
  if not type(v) == args_type_dic[key]:
797
  v = args_type_dic[key](v)
798
- args_dic[key] = v
799
  #最後にデフォから指定が変わってるものを変更する
800
  for key, v in change_def_dic.items():
801
  args_dic[key] = v
802
  else:
803
  print(f"{config_path} が見つかりませんでした")
804
  return args
805
-
806
- '''
807
- class GradientReversalFunction(torch.autograd.Function):
808
- @staticmethod
809
- def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
810
- ctx.save_for_backward(scale)
811
- return input_forward
812
- @staticmethod
813
- def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
814
- scale, = ctx.saved_tensors
815
- return scale * -grad_backward, None
816
-
817
- class GradientReversal(torch.nn.Module):
818
- def __init__(self, scale: float):
819
- super(GradientReversal, self).__init__()
820
- self.scale = torch.tensor(scale)
821
- def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
822
- if flag:
823
- return x
824
- else:
825
- return GradientReversalFunction.apply(x, self.scale)
826
- '''
 
2
  import json
3
  import shutil
4
  import time
5
+ from typing import Dict, List, NamedTuple, Tuple
 
6
  from accelerate import Accelerator
7
  from torch.autograd.function import Function
8
  import glob
 
28
 
29
  import library.model_util as model_util
30
  import library.train_util as train_util
 
31
 
32
  #============================================================================================================
33
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
 
115
  return area_size_resos_list, area_size_list
116
 
117
  #============================================================================================================
 
118
  #train_util 内より
119
  #============================================================================================================
120
  class BucketManager_append(train_util.BucketManager):
 
179
  bucket_size_id_list.append(bucket_size_id + i + 1)
180
  _min_error = 1000.
181
  _min_id = bucket_size_id
182
+ for now_size_id in bucket_size_id:
183
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
184
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
185
  ar_error = np.abs(ar_errors).min()
 
253
  return reso, resized_size, ar_error
254
 
255
  class DreamBoothDataset(train_util.DreamBoothDataset):
256
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
257
  print("use append DreamBoothDataset")
258
  self.min_resolution = min_resolution
259
  self.area_step = area_step
260
+ super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
261
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
262
+ flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
263
  def make_buckets(self):
264
  '''
265
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
 
352
  self.shuffle_buckets()
353
  self._length = len(self.buckets_indices)
354
 
355
+ class FineTuningDataset(train_util.FineTuningDataset):
356
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
357
+ train_util.glob_images = glob_images
358
+ super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
359
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
360
+ random_crop, dataset_repeats, debug_dataset)
361
+
362
+ def glob_images(directory, base="*", npz_flag=True):
363
+ img_paths = []
364
+ dots = []
365
+ for ext in train_util.IMAGE_EXTENSIONS:
366
+ dots.append(ext)
367
+ if npz_flag:
368
+ dots.append(".npz")
369
+ for ext in dots:
370
+ if base == '*':
371
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
372
+ else:
373
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
374
+ return img_paths
375
+
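  This glob_images differs from the train_util version in that it can also pick up cached .npz latent files next to the images. A minimal usage sketch (the directory name is an assumption):

  # gather image files plus any cached .npz latents for one subset directory
  paths = glob_images("train_data/1_mychar", base="*", npz_flag=True)
  print(f"{len(paths)} files matched")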
376
  #============================================================================================================
377
  #networks.lora
378
  #============================================================================================================
379
+ from networks.lora import LoRANetwork
380
+ def replace_prepare_optimizer_params(networks):
381
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
382
+
383
  def enumerate_params(loras, lora_name=None):
384
  params = []
385
  for lora in loras:
386
  if lora_name is not None:
387
+ if lora_name in lora.lora_name:
388
+ params.extend(lora.parameters())
 
389
  else:
390
  params.extend(lora.parameters())
391
  return params
 
393
  self.requires_grad_(True)
394
  all_params = []
395
  ret_scheduler_lr = []
 
396
 
397
  if loranames is not None:
398
  textencoder_names = [None]
 
405
  if self.text_encoder_loras:
406
  for textencoder_name in textencoder_names:
407
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
 
408
  if text_encoder_lr is not None:
409
  param_data['lr'] = text_encoder_lr
410
+ if scheduler_lr is not None:
411
+ ret_scheduler_lr.append(scheduler_lr[0])
 
412
  all_params.append(param_data)
413
 
414
  if self.unet_loras:
415
  for unet_name in unet_names:
416
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
 
 
417
  if unet_lr is not None:
418
  param_data['lr'] = unet_lr
419
+ if scheduler_lr is not None:
420
+ ret_scheduler_lr.append(scheduler_lr[1])
 
421
  all_params.append(param_data)
422
 
423
+ return all_params, ret_scheduler_lr
424
+
425
+ LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
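  replace_prepare_optimizer_params monkey-patches LoRANetwork so that parameter groups are built per block-name prefix and a matching list of scheduler learning rates is returned. A rough sketch of how the training script presumably wires it together (the caller is outside this diff, so the variable names are assumptions):

  import append_module
  import networks.lora as lora_module

  append_module.replace_prepare_optimizer_params(lora_module)  # patches LoRANetwork in place
  split_names = append_module.create_split_names(split_flag=True, split_level=0)
  # network, text_encoder_lr and unet_lr come from the surrounding training script
  trainable_params, scheduler_lrs = network.prepare_optimizer_params(
      text_encoder_lr, unet_lr,
      scheduler_lr=(text_encoder_lr, unet_lr),  # per-group initial lrs for AdafactorSchedule_append
      loranames=split_names)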
 
 
426
 
427
  #============================================================================================================
428
  #新規追加
429
  #============================================================================================================
430
  def add_append_arguments(parser: argparse.ArgumentParser):
431
  # for train_network_opt.py
432
+ parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
433
+ parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
 
 
434
  parser.add_argument("--split_lora_networks", action="store_true")
435
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
 
 
436
  parser.add_argument("--min_resolution", type=str, default=None)
437
  parser.add_argument("--area_step", type=int, default=1)
438
  parser.add_argument("--config", type=str, default=None)
 
439
 
440
  def create_split_names(split_flag, split_level):
441
  split_names = None
 
446
  if split_level==1:
447
  unet_names.append(f"lora_unet_down_blocks_")
448
  unet_names.append(f"lora_unet_up_blocks_")
449
+ elif split_level==2 or split_level==0:
450
+ if split_level==2:
451
  text_encoder_names = []
452
  for i in range(12):
453
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
454
+ for i in range(3):
455
+ unet_names.append(f"lora_unet_down_blocks_{i}")
456
+ unet_names.append(f"lora_unet_up_blocks_{i+1}")
 
457
  split_names["text_encoder"] = text_encoder_names
458
  split_names["unet"] = unet_names
459
  return split_names
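  create_split_names builds the list of LoRA name prefixes that prepare_optimizer_params receives as loranames. A small sketch of calling it (the function's preamble is outside this hunk, so the exact initial list contents are assumed):

  import append_module

  split_names = append_module.create_split_names(split_flag=True, split_level=2)
  # split_level=2 splits the text encoder per layer (lora_te_text_model_encoder_layers_0_ .. _11_)
  # and the U-Net per down/up block (lora_unet_down_blocks_0..2, lora_unet_up_blocks_1..3).
  for prefix in split_names["unet"]:
      print(prefix)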
 
465
  import datetime
466
  if os.path.splitext(args.config)[-1] == ".yaml":
467
  args.config = os.path.splitext(args.config)[0]
468
+ config_path = f"./{args.config}.yaml"
469
  if os.path.exists(config_path):
470
  print(f"{config_path} から設定を読み込み中...")
471
  margs, rest = parser.parse_known_args()
 
486
  args_type_dic[key] = act.type
487
  #データタイプの確認とargsにkeyの内容を代入していく
488
  for key, v in configs.items():
489
+ if key in args_dic:
490
+ if args_dic[key] is not None:
491
+ new_type = type(args_dic[key])
492
+ if (not type(v) == new_type) and (not new_type==list):
493
+ v = new_type(v)
494
+ else:
495
+ if v is not None:
496
  if not type(v) == args_type_dic[key]:
497
  v = args_type_dic[key](v)
498
+ args_dic[key] = v
499
  #最後にデフォから指定が変わってるものを変更する
500
  for key, v in change_def_dic.items():
501
  args_dic[key] = v
502
  else:
503
  print(f"{config_path} が見つかりませんでした")
504
  return args
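  get_config reads <name>.yaml, coerces each value to the type already held by the matching argparse entry, writes it onto args, and finally re-applies any value that was explicitly changed on the command line so it wins over the file. A rough standalone equivalent of the merge step (the PyYAML import and file layout are assumptions):

  import yaml

  def merge_yaml_into_args(args, config_path):
      # every YAML key overrides the argparse value, coerced to the existing value's type
      with open(config_path, encoding="utf-8") as f:
          configs = yaml.safe_load(f)
      for key, value in configs.items():
          current = getattr(args, key, None)
          if current is not None and not isinstance(current, list) and not isinstance(value, type(current)):
              value = type(current)(value)
          setattr(args, key, value)
      return args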
 
fine_tune.py CHANGED
@@ -13,11 +13,7 @@ import diffusers
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
- import library.config_util as config_util
17
- from library.config_util import (
18
- ConfigSanitizer,
19
- BlueprintGenerator,
20
- )
21
 
22
  def collate_fn(examples):
23
  return examples[0]
@@ -34,36 +30,25 @@ def train(args):
34
 
35
  tokenizer = train_util.load_tokenizer(args)
36
 
37
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
38
- if args.dataset_config is not None:
39
- print(f"Load dataset config from {args.dataset_config}")
40
- user_config = config_util.load_user_config(args.dataset_config)
41
- ignored = ["train_data_dir", "in_json"]
42
- if any(getattr(args, attr) is not None for attr in ignored):
43
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
44
- else:
45
- user_config = {
46
- "datasets": [{
47
- "subsets": [{
48
- "image_dir": args.train_data_dir,
49
- "metadata_file": args.in_json,
50
- }]
51
- }]
52
- }
53
-
54
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.debug_dataset:
58
- train_util.debug_dataset(train_dataset_group)
59
  return
60
- if len(train_dataset_group) == 0:
61
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
62
  return
63
 
64
- if cache_latents:
65
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
-
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -124,7 +109,7 @@ def train(args):
124
  vae.requires_grad_(False)
125
  vae.eval()
126
  with torch.no_grad():
127
- train_dataset_group.cache_latents(vae)
128
  vae.to("cpu")
129
  if torch.cuda.is_available():
130
  torch.cuda.empty_cache()
@@ -164,13 +149,33 @@ def train(args):
164
 
165
  # 学習に必要なクラスを準備する
166
  print("prepare optimizer, data loader etc.")
167
- _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
 
168
 
169
  # dataloaderを準備する
170
  # DataLoaderのプロセス数:0はメインプロセスになる
171
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
172
  train_dataloader = torch.utils.data.DataLoader(
173
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
174
 
175
  # 学習ステップ数を計算する
176
  if args.max_train_epochs is not None:
@@ -178,9 +183,8 @@ def train(args):
178
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
179
 
180
  # lr schedulerを用意する
181
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
182
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
183
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
184
 
185
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
186
  if args.full_fp16:
@@ -214,7 +218,7 @@ def train(args):
214
  # 学習する
215
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
216
  print("running training / 学習開始")
217
- print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
218
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
219
  print(f" num epochs / epoch数: {num_train_epochs}")
220
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -233,7 +237,7 @@ def train(args):
233
 
234
  for epoch in range(num_train_epochs):
235
  print(f"epoch {epoch+1}/{num_train_epochs}")
236
- train_dataset_group.set_current_epoch(epoch + 1)
237
 
238
  for m in training_models:
239
  m.train()
@@ -282,11 +286,11 @@ def train(args):
282
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
283
 
284
  accelerator.backward(loss)
285
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
286
  params_to_clip = []
287
  for m in training_models:
288
  params_to_clip.extend(m.parameters())
289
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
@@ -297,16 +301,11 @@ def train(args):
297
  progress_bar.update(1)
298
  global_step += 1
299
 
300
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
301
-
302
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
303
  if args.logging_dir is not None:
304
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
305
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
306
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
307
  accelerator.log(logs, step=global_step)
308
 
309
- # TODO moving averageにする
310
  loss_total += current_loss
311
  avr_loss = loss_total / (step+1)
312
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -316,7 +315,7 @@ def train(args):
316
  break
317
 
318
  if args.logging_dir is not None:
319
- logs = {"loss/epoch": loss_total / len(train_dataloader)}
320
  accelerator.log(logs, step=epoch+1)
321
 
322
  accelerator.wait_for_everyone()
@@ -326,8 +325,6 @@ def train(args):
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
329
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
330
-
331
  is_main_process = accelerator.is_main_process
332
  if is_main_process:
333
  unet = unwrap_model(unet)
@@ -354,8 +351,6 @@ if __name__ == '__main__':
354
  train_util.add_dataset_arguments(parser, False, True, True)
355
  train_util.add_training_arguments(parser, False)
356
  train_util.add_sd_saving_arguments(parser)
357
- train_util.add_optimizer_arguments(parser)
358
- config_util.add_config_arguments(parser)
359
 
360
  parser.add_argument("--diffusers_xformers", action='store_true',
361
  help='use xformers by diffusers / Diffusersでxformersを使用する')
 
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
+
 
17
 
18
  def collate_fn(examples):
19
  return examples[0]
 
30
 
31
  tokenizer = train_util.load_tokenizer(args)
32
 
33
+ train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
34
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
35
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
36
+ args.bucket_reso_steps, args.bucket_no_upscale,
37
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
38
+ args.dataset_repeats, args.debug_dataset)
39
+
40
+ # 学習データのdropout率を設定する
41
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
42
+
43
+ train_dataset.make_buckets()
 
 
44
 
45
  if args.debug_dataset:
46
+ train_util.debug_dataset(train_dataset)
47
  return
48
+ if len(train_dataset) == 0:
49
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
50
  return
51
 
 
 
 
52
  # acceleratorを準備する
53
  print("prepare accelerator")
54
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
109
  vae.requires_grad_(False)
110
  vae.eval()
111
  with torch.no_grad():
112
+ train_dataset.cache_latents(vae)
113
  vae.to("cpu")
114
  if torch.cuda.is_available():
115
  torch.cuda.empty_cache()
 
149
 
150
  # 学習に必要なクラスを準備する
151
  print("prepare optimizer, data loader etc.")
152
+
153
+ # 8-bit Adamを使う
154
+ if args.use_8bit_adam:
155
+ try:
156
+ import bitsandbytes as bnb
157
+ except ImportError:
158
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
159
+ print("use 8-bit Adam optimizer")
160
+ optimizer_class = bnb.optim.AdamW8bit
161
+ elif args.use_lion_optimizer:
162
+ try:
163
+ import lion_pytorch
164
+ except ImportError:
165
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
166
+ print("use Lion optimizer")
167
+ optimizer_class = lion_pytorch.Lion
168
+ else:
169
+ optimizer_class = torch.optim.AdamW
170
+
171
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
172
+ optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
173
 
174
  # dataloaderを準備する
175
  # DataLoaderのプロセス数:0はメインプロセスになる
176
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
177
  train_dataloader = torch.utils.data.DataLoader(
178
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
179
 
180
  # 学習ステップ数を計算する
181
  if args.max_train_epochs is not None:
 
183
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
+ lr_scheduler = diffusers.optimization.get_scheduler(
187
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
188
 
189
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
190
  if args.full_fp16:
 
218
  # 学習する
219
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
220
  print("running training / 学習開始")
221
+ print(f" num examples / サンプル数: {train_dataset.num_train_images}")
222
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
223
  print(f" num epochs / epoch数: {num_train_epochs}")
224
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
237
 
238
  for epoch in range(num_train_epochs):
239
  print(f"epoch {epoch+1}/{num_train_epochs}")
240
+ train_dataset.set_current_epoch(epoch + 1)
241
 
242
  for m in training_models:
243
  m.train()
 
286
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
287
 
288
  accelerator.backward(loss)
289
+ if accelerator.sync_gradients:
290
  params_to_clip = []
291
  for m in training_models:
292
  params_to_clip.extend(m.parameters())
293
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
294
 
295
  optimizer.step()
296
  lr_scheduler.step()
 
301
  progress_bar.update(1)
302
  global_step += 1
303
 
 
 
304
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
305
  if args.logging_dir is not None:
306
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
307
  accelerator.log(logs, step=global_step)
308
 
 
309
  loss_total += current_loss
310
  avr_loss = loss_total / (step+1)
311
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 
315
  break
316
 
317
  if args.logging_dir is not None:
318
+ logs = {"epoch_loss": loss_total / len(train_dataloader)}
319
  accelerator.log(logs, step=epoch+1)
320
 
321
  accelerator.wait_for_everyone()
 
325
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
326
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
327
 
 
 
328
  is_main_process = accelerator.is_main_process
329
  if is_main_process:
330
  unet = unwrap_model(unet)
 
351
  train_util.add_dataset_arguments(parser, False, True, True)
352
  train_util.add_training_arguments(parser, False)
353
  train_util.add_sd_saving_arguments(parser)
 
 
354
 
355
  parser.add_argument("--diffusers_xformers", action='store_true',
356
  help='use xformers by diffusers / Diffusersでxformersを使用する')
gen_img_diffusers.py CHANGED
@@ -47,7 +47,7 @@ VGG(
47
  """
48
 
49
  import json
50
- from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
51
  import glob
52
  import importlib
53
  import inspect
@@ -60,6 +60,7 @@ import math
60
  import os
61
  import random
62
  import re
 
63
 
64
  import diffusers
65
  import numpy as np
@@ -80,9 +81,6 @@ from PIL import Image
80
  from PIL.PngImagePlugin import PngInfo
81
 
82
  import library.model_util as model_util
83
- import library.train_util as train_util
84
- import tools.original_control_net as original_control_net
85
- from tools.original_control_net import ControlNetInfo
86
 
87
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
88
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -489,9 +487,6 @@ class PipelineLike():
489
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
490
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
491
 
492
- # ControlNet
493
- self.control_nets: List[ControlNetInfo] = []
494
-
495
  # Textual Inversion
496
  def add_token_replacement(self, target_token_id, rep_token_ids):
497
  self.token_replacements[target_token_id] = rep_token_ids
@@ -505,11 +500,7 @@ class PipelineLike():
505
  new_tokens.append(token)
506
  return new_tokens
507
 
508
- def set_control_nets(self, ctrl_nets):
509
- self.control_nets = ctrl_nets
510
-
511
  # region xformersとか使う部分:独自に書き換えるので関係なし
512
-
513
  def enable_xformers_memory_efficient_attention(self):
514
  r"""
515
  Enable memory efficient attention as implemented in xformers.
@@ -590,8 +581,6 @@ class PipelineLike():
590
  latents: Optional[torch.FloatTensor] = None,
591
  max_embeddings_multiples: Optional[int] = 3,
592
  output_type: Optional[str] = "pil",
593
- vae_batch_size: float = None,
594
- return_latents: bool = False,
595
  # return_dict: bool = True,
596
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -683,9 +672,6 @@ class PipelineLike():
683
  else:
684
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
685
 
686
- vae_batch_size = batch_size if vae_batch_size is None else (
687
- int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
688
-
689
  if strength < 0 or strength > 1:
690
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
691
 
@@ -766,7 +752,7 @@ class PipelineLike():
766
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
767
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
768
 
769
- if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
770
  if isinstance(clip_guide_images, PIL.Image.Image):
771
  clip_guide_images = [clip_guide_images]
772
 
@@ -779,7 +765,7 @@ class PipelineLike():
779
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
780
  if len(image_embeddings_clip) == 1:
781
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
782
- elif self.vgg16_guidance_scale > 0:
783
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
784
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
785
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -788,10 +774,6 @@ class PipelineLike():
788
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
789
  if len(image_embeddings_vgg16) == 1:
790
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
791
- else:
792
- # ControlNetのhintにguide imageを流用する
793
- # 前処理はControlNet側で行う
794
- pass
795
 
796
  # set timesteps
797
  self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -799,6 +781,7 @@ class PipelineLike():
799
  latents_dtype = text_embeddings.dtype
800
  init_latents_orig = None
801
  mask = None
 
802
 
803
  if init_image is None:
804
  # get the initial random noise unless the user supplied it
@@ -830,8 +813,6 @@ class PipelineLike():
830
  if isinstance(init_image[0], PIL.Image.Image):
831
  init_image = [preprocess_image(im) for im in init_image]
832
  init_image = torch.cat(init_image)
833
- if isinstance(init_image, list):
834
- init_image = torch.stack(init_image)
835
 
836
  # mask image to tensor
837
  if mask_image is not None:
@@ -842,24 +823,9 @@ class PipelineLike():
842
 
843
  # encode the init image into latents and scale the latents
844
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
845
- if init_image.size()[2:] == (height // 8, width // 8):
846
- init_latents = init_image
847
- else:
848
- if vae_batch_size >= batch_size:
849
- init_latent_dist = self.vae.encode(init_image).latent_dist
850
- init_latents = init_latent_dist.sample(generator=generator)
851
- else:
852
- if torch.cuda.is_available():
853
- torch.cuda.empty_cache()
854
- init_latents = []
855
- for i in tqdm(range(0, batch_size, vae_batch_size)):
856
- init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
857
- if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
858
- init_latents.append(init_latent_dist.sample(generator=generator))
859
- init_latents = torch.cat(init_latents)
860
-
861
- init_latents = 0.18215 * init_latents
862
-
863
  if len(init_latents) == 1:
864
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
865
  init_latents_orig = init_latents
@@ -898,21 +864,12 @@ class PipelineLike():
898
  extra_step_kwargs["eta"] = eta
899
 
900
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
901
-
902
- if self.control_nets:
903
- guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
904
-
905
  for i, t in enumerate(tqdm(timesteps)):
906
  # expand the latents if we are doing classifier free guidance
907
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
908
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
909
-
910
  # predict the noise residual
911
- if self.control_nets:
912
- noise_pred = original_control_net.call_unet_and_control_net(
913
- i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
914
- else:
915
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
916
 
917
  # perform guidance
918
  if do_classifier_free_guidance:
@@ -954,19 +911,8 @@ class PipelineLike():
954
  if is_cancelled_callback is not None and is_cancelled_callback():
955
  return None
956
 
957
- if return_latents:
958
- return (latents, False)
959
-
960
  latents = 1 / 0.18215 * latents
961
- if vae_batch_size >= batch_size:
962
- image = self.vae.decode(latents).sample
963
- else:
964
- if torch.cuda.is_available():
965
- torch.cuda.empty_cache()
966
- images = []
967
- for i in tqdm(range(0, batch_size, vae_batch_size)):
968
- images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
969
- image = torch.cat(images)
970
 
971
  image = (image / 2 + 0.5).clamp(0, 1)
972
 
@@ -1853,7 +1799,7 @@ def preprocess_mask(mask):
1853
  mask = mask.convert("L")
1854
  w, h = mask.size
1855
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1856
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
1857
  mask = np.array(mask).astype(np.float32) / 255.0
1858
  mask = np.tile(mask, (4, 1, 1))
1859
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -1871,35 +1817,6 @@ def preprocess_mask(mask):
1871
  # return text_encoder
1872
 
1873
 
1874
- class BatchDataBase(NamedTuple):
1875
- # バッチ分割が必要ないデータ
1876
- step: int
1877
- prompt: str
1878
- negative_prompt: str
1879
- seed: int
1880
- init_image: Any
1881
- mask_image: Any
1882
- clip_prompt: str
1883
- guide_image: Any
1884
-
1885
-
1886
- class BatchDataExt(NamedTuple):
1887
- # バッチ分割が必要なデータ
1888
- width: int
1889
- height: int
1890
- steps: int
1891
- scale: float
1892
- negative_scale: float
1893
- strength: float
1894
- network_muls: Tuple[float]
1895
-
1896
-
1897
- class BatchData(NamedTuple):
1898
- return_latents: bool
1899
- base: BatchDataBase
1900
- ext: BatchDataExt
1901
-
1902
-
1903
  def main(args):
1904
  if args.fp16:
1905
  dtype = torch.float16
@@ -1964,7 +1881,10 @@ def main(args):
1964
  # tokenizerを読み込む
1965
  print("loading tokenizer")
1966
  if use_stable_diffusion_format:
1967
- tokenizer = train_util.load_tokenizer(args)
 
 
 
1968
 
1969
  # schedulerを用意する
1970
  sched_init_args = {}
@@ -2075,13 +1995,11 @@ def main(args):
2075
  # networkを組み込む
2076
  if args.network_module:
2077
  networks = []
2078
- network_default_muls = []
2079
  for i, network_module in enumerate(args.network_module):
2080
  print("import network module:", network_module)
2081
  imported_module = importlib.import_module(network_module)
2082
 
2083
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
2084
- network_default_muls.append(network_mul)
2085
 
2086
  net_kwargs = {}
2087
  if args.network_args and i < len(args.network_args):
@@ -2096,7 +2014,7 @@ def main(args):
2096
  network_weight = args.network_weights[i]
2097
  print("load network weights from:", network_weight)
2098
 
2099
- if model_util.is_safetensors(network_weight) and args.network_show_meta:
2100
  from safetensors.torch import safe_open
2101
  with safe_open(network_weight, framework="pt") as f:
2102
  metadata = f.metadata()
@@ -2119,18 +2037,6 @@ def main(args):
2119
  else:
2120
  networks = []
2121
 
2122
- # ControlNetの処理
2123
- control_nets: List[ControlNetInfo] = []
2124
- if args.control_net_models:
2125
- for i, model in enumerate(args.control_net_models):
2126
- prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
2127
- weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
2128
- ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
2129
-
2130
- ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
2131
- prep = original_control_net.load_preprocess(prep_type)
2132
- control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
2133
-
2134
  if args.opt_channels_last:
2135
  print(f"set optimizing: channels last")
2136
  text_encoder.to(memory_format=torch.channels_last)
@@ -2144,14 +2050,9 @@ def main(args):
2144
  if vgg16_model is not None:
2145
  vgg16_model.to(memory_format=torch.channels_last)
2146
 
2147
- for cn in control_nets:
2148
- cn.unet.to(memory_format=torch.channels_last)
2149
- cn.net.to(memory_format=torch.channels_last)
2150
-
2151
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2152
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2153
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
2154
- pipe.set_control_nets(control_nets)
2155
  print("pipeline is ready.")
2156
 
2157
  if args.diffusers_xformers:
@@ -2285,12 +2186,9 @@ def main(args):
2285
 
2286
  prev_image = None # for VGG16 guided
2287
  if args.guide_image_path is not None:
2288
- print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
2289
- guide_images = []
2290
- for p in args.guide_image_path:
2291
- guide_images.extend(load_images(p))
2292
-
2293
- print(f"loaded {len(guide_images)} guide images for guidance")
2294
  if len(guide_images) == 0:
2295
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2296
  guide_images = None
@@ -2321,46 +2219,33 @@ def main(args):
2321
  iter_seed = random.randint(0, 0x7fffffff)
2322
 
2323
  # バッチ処理の関数
2324
- def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
2325
  batch_size = len(batch)
2326
 
2327
  # highres_fixの処理
2328
  if highres_fix and not highres_1st:
2329
- # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
2330
- print("process 1st stage")
2331
  batch_1st = []
2332
- for _, base, ext in batch:
2333
- width_1st = int(ext.width * args.highres_fix_scale + .5)
2334
- height_1st = int(ext.height * args.highres_fix_scale + .5)
2335
  width_1st = width_1st - width_1st % 32
2336
  height_1st = height_1st - height_1st % 32
2337
-
2338
- ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
2339
- ext.negative_scale, ext.strength, ext.network_muls)
2340
- batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
2341
  images_1st = process_batch(batch_1st, True, True)
2342
 
2343
  # 2nd stageのバッチを作成して以下処理する
2344
- print("process 2nd stage")
2345
- if args.highres_fix_latents_upscaling:
2346
- org_dtype = images_1st.dtype
2347
- if images_1st.dtype == torch.bfloat16:
2348
- images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
2349
- images_1st = torch.nn.functional.interpolate(
2350
- images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
2351
- images_1st = images_1st.to(org_dtype)
2352
-
2353
  batch_2nd = []
2354
- for i, (bd, image) in enumerate(zip(batch, images_1st)):
2355
- if not args.highres_fix_latents_upscaling:
2356
- image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
2357
- bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
2358
- batch_2nd.append(bd_2nd)
2359
  batch = batch_2nd
2360
 
2361
- # このバッチの情報を取り出す
2362
- return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
2363
- (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
2364
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2365
 
2366
  prompts = []
@@ -2393,7 +2278,7 @@ def main(args):
2393
  all_images_are_same = True
2394
  all_masks_are_same = True
2395
  all_guide_images_are_same = True
2396
- for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2397
  prompts.append(prompt)
2398
  negative_prompts.append(negative_prompt)
2399
  seeds.append(seed)
@@ -2410,13 +2295,9 @@ def main(args):
2410
  all_masks_are_same = mask_images[-2] is mask_image
2411
 
2412
  if guide_image is not None:
2413
- if type(guide_image) is list:
2414
- guide_images.extend(guide_image)
2415
- all_guide_images_are_same = False
2416
- else:
2417
- guide_images.append(guide_image)
2418
- if i > 0 and all_guide_images_are_same:
2419
- all_guide_images_are_same = guide_images[-2] is guide_image
2420
 
2421
  # make start code
2422
  torch.manual_seed(seed)
@@ -2439,24 +2320,10 @@ def main(args):
2439
  if guide_images is not None and all_guide_images_are_same:
2440
  guide_images = guide_images[0]
2441
 
2442
- # ControlNet使用時はguide imageをリサイズする
2443
- if control_nets:
2444
- # TODO resampleのメソッド
2445
- guide_images = guide_images if type(guide_images) == list else [guide_images]
2446
- guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
2447
- if len(guide_images) == 1:
2448
- guide_images = guide_images[0]
2449
-
2450
  # generate
2451
- if networks:
2452
- for n, m in zip(networks, network_muls if network_muls else network_default_muls):
2453
- n.set_multiplier(m)
2454
-
2455
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2456
- output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
2457
- vae_batch_size=args.vae_batch_size, return_latents=return_latents,
2458
- clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2459
- if highres_1st and not args.highres_fix_save_1st: # return images or latents
2460
  return images
2461
 
2462
  # save image
@@ -2531,7 +2398,6 @@ def main(args):
2531
  strength = 0.8 if args.strength is None else args.strength
2532
  negative_prompt = ""
2533
  clip_prompt = None
2534
- network_muls = None
2535
 
2536
  prompt_args = prompt.strip().split(' --')
2537
  prompt = prompt_args[0]
@@ -2595,15 +2461,6 @@ def main(args):
2595
  clip_prompt = m.group(1)
2596
  print(f"clip prompt: {clip_prompt}")
2597
  continue
2598
-
2599
- m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
2600
- if m: # network multiplies
2601
- network_muls = [float(v) for v in m.group(1).split(",")]
2602
- while len(network_muls) < len(networks):
2603
- network_muls.append(network_muls[-1])
2604
- print(f"network mul: {network_muls}")
2605
- continue
2606
-
2607
  except ValueError as ex:
2608
  print(f"Exception in parsing / 解析エラー: {parg}")
2609
  print(ex)
@@ -2641,12 +2498,7 @@ def main(args):
2641
  mask_image = mask_images[global_step % len(mask_images)]
2642
 
2643
  if guide_images is not None:
2644
- if control_nets: # 複数件の場合あり
2645
- c = len(control_nets)
2646
- p = global_step % (len(guide_images) // c)
2647
- guide_image = guide_images[p * c:p * c + c]
2648
- else:
2649
- guide_image = guide_images[global_step % len(guide_images)]
2650
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2651
  if prev_image is None:
2652
  print("Generate 1st image without guide image.")
@@ -2654,9 +2506,10 @@ def main(args):
2654
  print("Use previous image as guide image.")
2655
  guide_image = prev_image
2656
 
2657
- b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2658
- BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
2659
- if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
 
2660
  process_batch(batch_data, highres_fix)
2661
  batch_data.clear()
2662
 
@@ -2700,8 +2553,6 @@ if __name__ == '__main__':
2700
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2701
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2702
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
2703
- parser.add_argument("--vae_batch_size", type=float, default=None,
2704
- help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
2705
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2706
  parser.add_argument('--sampler', type=str, default='ddim',
2707
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2713,8 +2564,6 @@ if __name__ == '__main__':
2713
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2714
  parser.add_argument("--vae", type=str, default=None,
2715
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
2716
- parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
2717
- help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
2718
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2719
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2720
  parser.add_argument("--seed", type=int, default=None,
@@ -2729,15 +2578,12 @@ if __name__ == '__main__':
2729
  parser.add_argument("--opt_channels_last", action='store_true',
2730
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2731
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2732
- help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
2733
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2734
- help='additional network weights to load / 追加ネットワークの重み')
2735
- parser.add_argument("--network_mul", type=float, default=None, nargs='*',
2736
- help='additional network multiplier / 追加ネットワークの効果の倍率')
2737
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2738
  help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
2739
- parser.add_argument("--network_show_meta", action='store_true',
2740
- help='show metadata of network model / ネットワークモデルのメタデータを表示する')
2741
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2742
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2743
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2751,26 +2597,15 @@ if __name__ == '__main__':
2751
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2752
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2753
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2754
- parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
2755
- help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
2756
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2757
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2758
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2759
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2760
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2761
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
2762
- parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
2763
- help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
2764
  parser.add_argument("--negative_scale", type=float, default=None,
2765
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2766
 
2767
- parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
2768
- help='ControlNet models to use / 使用するControlNetのモデル名')
2769
- parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
2770
- help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
2771
- parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
2772
- parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
2773
- help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
2774
-
2775
  args = parser.parse_args()
2776
  main(args)
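Note on the highres-fix hunk removed above: it computes the 1st-stage resolution by scaling the requested size with --highres_fix_scale and rounding down to a multiple of 32, and (when --highres_fix_latents_upscaling is set) interpolates the 1st-stage latents to height // 8 x width // 8 before the 2nd stage. A minimal standalone sketch of just that size arithmetic, assuming only what the diff shows (the helper name is illustrative, not part of the script):

def first_stage_size(width, height, scale):
    # Scale the requested resolution for the 1st stage of the highres fix.
    width_1st = int(width * scale + .5)
    height_1st = int(height * scale + .5)
    # The script constrains generation sizes to multiples of 32, so round down.
    width_1st -= width_1st % 32
    height_1st -= height_1st % 32
    return width_1st, height_1st

# Example: a 768x1024 request with --highres_fix_scale 0.5 runs the 1st stage at 384x512;
# with latents upscaling, the 1st-stage latents are then interpolated to (1024 // 8, 768 // 8).
print(first_stage_size(768, 1024, 0.5))  # (384, 512)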
 
47
  """
48
 
49
  import json
50
+ from typing import List, Optional, Union
51
  import glob
52
  import importlib
53
  import inspect
 
60
  import os
61
  import random
62
  import re
63
+ from typing import Any, Callable, List, Optional, Union
64
 
65
  import diffusers
66
  import numpy as np
 
81
  from PIL.PngImagePlugin import PngInfo
82
 
83
  import library.model_util as model_util
 
 
 
84
 
85
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
86
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 
487
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
488
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
489
 
 
 
 
490
  # Textual Inversion
491
  def add_token_replacement(self, target_token_id, rep_token_ids):
492
  self.token_replacements[target_token_id] = rep_token_ids
 
500
  new_tokens.append(token)
501
  return new_tokens
502
 
 
 
 
503
  # region xformersとか使う部分:独自に書き換えるので関係なし
 
504
  def enable_xformers_memory_efficient_attention(self):
505
  r"""
506
  Enable memory efficient attention as implemented in xformers.
 
581
  latents: Optional[torch.FloatTensor] = None,
582
  max_embeddings_multiples: Optional[int] = 3,
583
  output_type: Optional[str] = "pil",
 
 
584
  # return_dict: bool = True,
585
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
586
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
 
672
  else:
673
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
674
 
 
 
 
675
  if strength < 0 or strength > 1:
676
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
677
 
 
752
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
753
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
754
 
755
+ if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
756
  if isinstance(clip_guide_images, PIL.Image.Image):
757
  clip_guide_images = [clip_guide_images]
758
 
 
765
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
766
  if len(image_embeddings_clip) == 1:
767
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
768
+ else:
769
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
770
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
771
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
 
774
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
775
  if len(image_embeddings_vgg16) == 1:
776
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
 
 
 
 
777
 
778
  # set timesteps
779
  self.scheduler.set_timesteps(num_inference_steps, self.device)
 
781
  latents_dtype = text_embeddings.dtype
782
  init_latents_orig = None
783
  mask = None
784
+ noise = None
785
 
786
  if init_image is None:
787
  # get the initial random noise unless the user supplied it
 
813
  if isinstance(init_image[0], PIL.Image.Image):
814
  init_image = [preprocess_image(im) for im in init_image]
815
  init_image = torch.cat(init_image)
 
 
816
 
817
  # mask image to tensor
818
  if mask_image is not None:
 
823
 
824
  # encode the init image into latents and scale the latents
825
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
826
+ init_latent_dist = self.vae.encode(init_image).latent_dist
827
+ init_latents = init_latent_dist.sample(generator=generator)
828
+ init_latents = 0.18215 * init_latents
 
 
 
829
  if len(init_latents) == 1:
830
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
831
  init_latents_orig = init_latents
 
864
  extra_step_kwargs["eta"] = eta
865
 
866
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
 
 
 
 
867
  for i, t in enumerate(tqdm(timesteps)):
868
  # expand the latents if we are doing classifier free guidance
869
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
870
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
871
  # predict the noise residual
872
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
 
 
873
 
874
  # perform guidance
875
  if do_classifier_free_guidance:
 
911
  if is_cancelled_callback is not None and is_cancelled_callback():
912
  return None
913
 
 
 
 
914
  latents = 1 / 0.18215 * latents
915
+ image = self.vae.decode(latents).sample
 
 
916
 
917
  image = (image / 2 + 0.5).clamp(0, 1)
918
 
 
1799
  mask = mask.convert("L")
1800
  w, h = mask.size
1801
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1802
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
1803
  mask = np.array(mask).astype(np.float32) / 255.0
1804
  mask = np.tile(mask, (4, 1, 1))
1805
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
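The hunk above shrinks an inpainting mask to latent resolution and tiles it across the 4 latent channels. A minimal standalone sketch of just those steps, assuming the dependencies the script already imports (PIL, numpy); the function name is illustrative, and the original's transpose(0, 1, 2, 3) is omitted because it is an identity permutation (which also answers the "what does this step do?" comment):

import numpy as np
import PIL.Image

def preprocess_mask_sketch(mask: PIL.Image.Image) -> np.ndarray:
    mask = mask.convert("L")                                          # grayscale
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))                          # floor to a multiple of 32
    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)  # shrink to latent resolution
    mask = np.array(mask).astype(np.float32) / 255.0                  # [0, 255] -> [0.0, 1.0]
    mask = np.tile(mask, (4, 1, 1))                                   # one copy per latent channel
    return mask[None]                                                 # add batch dim -> (1, 4, h // 8, w // 8)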
 
1817
  # return text_encoder
1818
 
1819
 
 
 
 
1820
  def main(args):
1821
  if args.fp16:
1822
  dtype = torch.float16
 
1881
  # tokenizerを読み込む
1882
  print("loading tokenizer")
1883
  if use_stable_diffusion_format:
1884
+ if args.v2:
1885
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1886
+ else:
1887
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
1888
 
1889
  # schedulerを用意する
1890
  sched_init_args = {}
 
1995
  # networkを組み込む
1996
  if args.network_module:
1997
  networks = []
 
1998
  for i, network_module in enumerate(args.network_module):
1999
  print("import network module:", network_module)
2000
  imported_module = importlib.import_module(network_module)
2001
 
2002
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
 
2003
 
2004
  net_kwargs = {}
2005
  if args.network_args and i < len(args.network_args):
 
2014
  network_weight = args.network_weights[i]
2015
  print("load network weights from:", network_weight)
2016
 
2017
+ if model_util.is_safetensors(network_weight):
2018
  from safetensors.torch import safe_open
2019
  with safe_open(network_weight, framework="pt") as f:
2020
  metadata = f.metadata()
 
2037
  else:
2038
  networks = []
2039
 
 
 
 
2040
  if args.opt_channels_last:
2041
  print(f"set optimizing: channels last")
2042
  text_encoder.to(memory_format=torch.channels_last)
 
2050
  if vgg16_model is not None:
2051
  vgg16_model.to(memory_format=torch.channels_last)
2052
 
 
 
 
 
2053
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2054
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2055
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
 
2056
  print("pipeline is ready.")
2057
 
2058
  if args.diffusers_xformers:
 
2186
 
2187
  prev_image = None # for VGG16 guided
2188
  if args.guide_image_path is not None:
2189
+ print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
2190
+ guide_images = load_images(args.guide_image_path)
2191
+ print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
 
 
 
2192
  if len(guide_images) == 0:
2193
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2194
  guide_images = None
 
2219
  iter_seed = random.randint(0, 0x7fffffff)
2220
 
2221
  # バッチ処理の関数
2222
+ def process_batch(batch, highres_fix, highres_1st=False):
2223
  batch_size = len(batch)
2224
 
2225
  # highres_fixの処理
2226
  if highres_fix and not highres_1st:
2227
+ # 1st stageのバッチを作成して呼び出す
2228
+ print("process 1st stage1")
2229
  batch_1st = []
2230
+ for params1, (width, height, steps, scale, negative_scale, strength) in batch:
2231
+ width_1st = int(width * args.highres_fix_scale + .5)
2232
+ height_1st = int(height * args.highres_fix_scale + .5)
2233
  width_1st = width_1st - width_1st % 32
2234
  height_1st = height_1st - height_1st % 32
2235
+ batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
 
 
 
2236
  images_1st = process_batch(batch_1st, True, True)
2237
 
2238
  # 2nd stageのバッチを作成して以下処理する
2239
+ print("process 2nd stage1")
 
 
2240
  batch_2nd = []
2241
+ for i, (b1, image) in enumerate(zip(batch, images_1st)):
2242
+ image = image.resize((width, height), resample=PIL.Image.LANCZOS)
2243
+ (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
2244
+ batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
 
2245
  batch = batch_2nd
2246
 
2247
+ (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
2248
+ height, steps, scale, negative_scale, strength) = batch[0]
 
2249
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2250
 
2251
  prompts = []
 
2278
  all_images_are_same = True
2279
  all_masks_are_same = True
2280
  all_guide_images_are_same = True
2281
+ for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2282
  prompts.append(prompt)
2283
  negative_prompts.append(negative_prompt)
2284
  seeds.append(seed)
 
2295
  all_masks_are_same = mask_images[-2] is mask_image
2296
 
2297
  if guide_image is not None:
2298
+ guide_images.append(guide_image)
2299
+ if i > 0 and all_guide_images_are_same:
2300
+ all_guide_images_are_same = guide_images[-2] is guide_image
 
 
 
 
2301
 
2302
  # make start code
2303
  torch.manual_seed(seed)
 
2320
  if guide_images is not None and all_guide_images_are_same:
2321
  guide_images = guide_images[0]
2322
 
 
2323
  # generate
 
 
 
 
2324
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2325
+ output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2326
+ if highres_1st and not args.highres_fix_save_1st:
 
 
2327
  return images
2328
 
2329
  # save image
 
2398
  strength = 0.8 if args.strength is None else args.strength
2399
  negative_prompt = ""
2400
  clip_prompt = None
 
2401
 
2402
  prompt_args = prompt.strip().split(' --')
2403
  prompt = prompt_args[0]
 
2461
  clip_prompt = m.group(1)
2462
  print(f"clip prompt: {clip_prompt}")
2463
  continue
 
 
2464
  except ValueError as ex:
2465
  print(f"Exception in parsing / 解析エラー: {parg}")
2466
  print(ex)
 
2498
  mask_image = mask_images[global_step % len(mask_images)]
2499
 
2500
  if guide_images is not None:
2501
+ guide_image = guide_images[global_step % len(guide_images)]
 
 
 
 
 
2502
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2503
  if prev_image is None:
2504
  print("Generate 1st image without guide image.")
 
2506
  print("Use previous image as guide image.")
2507
  guide_image = prev_image
2508
 
2509
+ # TODO named tupleか何かにする
2510
+ b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2511
+ (width, height, steps, scale, negative_scale, strength))
2512
+ if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
2513
  process_batch(batch_data, highres_fix)
2514
  batch_data.clear()
2515
 
 
2553
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2554
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2555
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
 
 
2556
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2557
  parser.add_argument('--sampler', type=str, default='ddim',
2558
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
 
2564
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2565
  parser.add_argument("--vae", type=str, default=None,
2566
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
 
 
2567
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2568
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2569
  parser.add_argument("--seed", type=int, default=None,
 
2578
  parser.add_argument("--opt_channels_last", action='store_true',
2579
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2580
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2581
+ help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
2582
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2583
+ help='Hypernetwork weights to load / Hypernetworkの重み')
2584
+ parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
 
2585
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2586
  help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
 
 
2587
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2588
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2589
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
 
2597
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2598
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2599
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2600
+ parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
 
2601
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2602
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2603
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2604
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2605
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2606
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
 
 
2607
  parser.add_argument("--negative_scale", type=float, default=None,
2608
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2609
 
 
2610
  args = parser.parse_args()
2611
  main(args)
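Note on the batching loop above: each prompt is stored as a pair of per-image data and a shared parameter tuple (width, height, steps, scale, negative_scale, strength), and the pending batch is flushed as soon as the parameter tuple changes, because one pipe() call must use a single set of generation parameters. A minimal sketch of that grouping rule, with illustrative names; the size-based flush is an assumption here, since the batch-size check is outside the lines shown:

def group_into_batches(entries, batch_size):
    # entries: iterable of (per_image_data, params) pairs, params being the shared tuple
    # (width, height, steps, scale, negative_scale, strength) consumed by process_batch().
    batches, pending = [], []
    for per_image, params in entries:
        # Flush when the shared parameters change, or when the batch is full (assumed limit).
        if pending and (pending[-1][1] != params or len(pending) >= batch_size):
            batches.append(pending)
            pending = []
        pending.append((per_image, params))
    if pending:
        batches.append(pending)
    return batches

entries = [(("a cat",), (512, 512, 30, 7.5, None, 0.8)),
           (("a dog",), (512, 512, 30, 7.5, None, 0.8)),
           (("a bird",), (640, 512, 30, 7.5, None, 0.8))]   # different width starts a new batch
print([len(b) for b in group_into_batches(entries, batch_size=4)])  # [2, 1]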
library/train_util.py CHANGED
@@ -1,21 +1,12 @@
1
  # common functions for training
2
 
3
  import argparse
4
- import importlib
5
  import json
6
- import re
7
  import shutil
8
  import time
9
- from typing import (
10
- Dict,
11
- List,
12
- NamedTuple,
13
- Optional,
14
- Sequence,
15
- Tuple,
16
- Union,
17
- )
18
  from accelerate import Accelerator
 
19
  import glob
20
  import math
21
  import os
@@ -26,16 +17,10 @@ from io import BytesIO
26
 
27
  from tqdm import tqdm
28
  import torch
29
- from torch.optim import Optimizer
30
  from torchvision import transforms
31
  from transformers import CLIPTokenizer
32
- import transformers
33
  import diffusers
34
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
- from diffusers import (StableDiffusionPipeline, DDPMScheduler,
36
- EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
37
- LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
38
- KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
39
  import albumentations as albu
40
  import numpy as np
41
  from PIL import Image
@@ -210,93 +195,23 @@ class BucketBatchIndex(NamedTuple):
210
  batch_index: int
211
 
212
 
213
- class AugHelper:
214
- def __init__(self):
215
- # prepare all possible augmentators
216
- color_aug_method = albu.OneOf([
217
- albu.HueSaturationValue(8, 0, 0, p=.5),
218
- albu.RandomGamma((95, 105), p=.5),
219
- ], p=.33)
220
- flip_aug_method = albu.HorizontalFlip(p=0.5)
221
-
222
- # key: (use_color_aug, use_flip_aug)
223
- self.augmentors = {
224
- (True, True): albu.Compose([
225
- color_aug_method,
226
- flip_aug_method,
227
- ], p=1.),
228
- (True, False): albu.Compose([
229
- color_aug_method,
230
- ], p=1.),
231
- (False, True): albu.Compose([
232
- flip_aug_method,
233
- ], p=1.),
234
- (False, False): None
235
- }
236
-
237
- def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
238
- return self.augmentors[(use_color_aug, use_flip_aug)]
239
-
240
-
241
- class BaseSubset:
242
- def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
243
- self.image_dir = image_dir
244
- self.num_repeats = num_repeats
245
- self.shuffle_caption = shuffle_caption
246
- self.keep_tokens = keep_tokens
247
- self.color_aug = color_aug
248
- self.flip_aug = flip_aug
249
- self.face_crop_aug_range = face_crop_aug_range
250
- self.random_crop = random_crop
251
- self.caption_dropout_rate = caption_dropout_rate
252
- self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
253
- self.caption_tag_dropout_rate = caption_tag_dropout_rate
254
-
255
- self.img_count = 0
256
-
257
-
258
- class DreamBoothSubset(BaseSubset):
259
- def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
260
- assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
261
-
262
- super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
263
- face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
264
-
265
- self.is_reg = is_reg
266
- self.class_tokens = class_tokens
267
- self.caption_extension = caption_extension
268
-
269
- def __eq__(self, other) -> bool:
270
- if not isinstance(other, DreamBoothSubset):
271
- return NotImplemented
272
- return self.image_dir == other.image_dir
273
-
274
- class FineTuningSubset(BaseSubset):
275
- def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
276
- assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
277
-
278
- super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
279
- face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
280
-
281
- self.metadata_file = metadata_file
282
-
283
- def __eq__(self, other) -> bool:
284
- if not isinstance(other, FineTuningSubset):
285
- return NotImplemented
286
- return self.metadata_file == other.metadata_file
287
-
288
  class BaseDataset(torch.utils.data.Dataset):
289
- def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
290
  super().__init__()
291
- self.tokenizer = tokenizer
292
  self.max_token_length = max_token_length
 
 
293
  # width/height is used when enable_bucket==False
294
  self.width, self.height = (None, None) if resolution is None else resolution
 
 
 
295
  self.debug_dataset = debug_dataset
296
-
297
- self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
298
-
299
  self.token_padding_disabled = False
 
 
300
  self.tag_frequency = {}
301
 
302
  self.enable_bucket = False
@@ -310,28 +225,49 @@ class BaseDataset(torch.utils.data.Dataset):
310
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
311
 
312
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
 
 
313
 
314
  # augmentation
315
- self.aug_helper = AugHelper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
318
 
319
  self.image_data: Dict[str, ImageInfo] = {}
320
- self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
321
 
322
  self.replacements = {}
323
 
324
  def set_current_epoch(self, epoch):
325
  self.current_epoch = epoch
326
- self.shuffle_buckets()
 
 
 
 
 
327
 
328
  def set_tag_frequency(self, dir_name, captions):
329
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
330
  self.tag_frequency[dir_name] = frequency_for_dir
331
  for caption in captions:
332
  for tag in caption.split(","):
333
- tag = tag.strip()
334
- if tag:
335
  tag = tag.lower()
336
  frequency = frequency_for_dir.get(tag, 0)
337
  frequency_for_dir[tag] = frequency + 1
@@ -342,36 +278,42 @@ class BaseDataset(torch.utils.data.Dataset):
342
  def add_replacement(self, str_from, str_to):
343
  self.replacements[str_from] = str_to
344
 
345
- def process_caption(self, subset: BaseSubset, caption):
346
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
347
- is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
348
- is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
349
 
350
  if is_drop_out:
351
  caption = ""
352
  else:
353
- if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
354
  def dropout_tags(tokens):
355
- if subset.caption_tag_dropout_rate <= 0:
356
  return tokens
357
  l = []
358
  for token in tokens:
359
- if random.random() >= subset.caption_tag_dropout_rate:
360
  l.append(token)
361
  return l
362
 
363
- fixed_tokens = []
364
- flex_tokens = [t.strip() for t in caption.strip().split(",")]
365
- if subset.keep_tokens > 0:
366
- fixed_tokens = flex_tokens[:subset.keep_tokens]
367
- flex_tokens = flex_tokens[subset.keep_tokens:]
 
 
 
 
 
368
 
369
- if subset.shuffle_caption:
370
- random.shuffle(flex_tokens)
371
 
372
- flex_tokens = dropout_tags(flex_tokens)
373
 
374
- caption = ", ".join(fixed_tokens + flex_tokens)
 
375
 
376
  # textual inversion対応
377
  for str_from, str_to in self.replacements.items():
@@ -425,9 +367,8 @@ class BaseDataset(torch.utils.data.Dataset):
425
  input_ids = torch.stack(iids_list) # 3,77
426
  return input_ids
427
 
428
- def register_image(self, info: ImageInfo, subset: BaseSubset):
429
  self.image_data[info.image_key] = info
430
- self.image_to_subset[info.image_key] = subset
431
 
432
  def make_buckets(self):
433
  '''
@@ -526,7 +467,7 @@ class BaseDataset(torch.utils.data.Dataset):
526
  img = np.array(image, np.uint8)
527
  return img
528
 
529
- def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
530
  image_height, image_width = image.shape[0:2]
531
 
532
  if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -536,27 +477,22 @@ class BaseDataset(torch.utils.data.Dataset):
536
  image_height, image_width = image.shape[0:2]
537
  if image_width > reso[0]:
538
  trim_size = image_width - reso[0]
539
- p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
540
  # print("w", trim_size, p)
541
  image = image[:, p:p + reso[0]]
542
  if image_height > reso[1]:
543
  trim_size = image_height - reso[1]
544
- p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
545
  # print("h", trim_size, p)
546
  image = image[p:p + reso[1]]
547
 
548
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
549
  return image
550
 
551
- def is_latent_cacheable(self):
552
- return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
553
-
554
  def cache_latents(self, vae):
555
  # TODO ここを高速化したい
556
  print("caching latents.")
557
  for info in tqdm(self.image_data.values()):
558
- subset = self.image_to_subset[info.image_key]
559
-
560
  if info.latents_npz is not None:
561
  info.latents = self.load_latents_from_npz(info, False)
562
  info.latents = torch.FloatTensor(info.latents)
@@ -566,13 +502,13 @@ class BaseDataset(torch.utils.data.Dataset):
566
  continue
567
 
568
  image = self.load_image(info.absolute_path)
569
- image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
570
 
571
  img_tensor = self.image_transforms(image)
572
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
573
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
574
 
575
- if subset.flip_aug:
576
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
577
  img_tensor = self.image_transforms(image)
578
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -582,11 +518,11 @@ class BaseDataset(torch.utils.data.Dataset):
582
  image = Image.open(image_path)
583
  return image.size
584
 
585
- def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
586
  img = self.load_image(image_path)
587
 
588
  face_cx = face_cy = face_w = face_h = 0
589
- if subset.face_crop_aug_range is not None:
590
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
591
  if len(tokens) >= 5:
592
  face_cx = int(tokens[-4])
@@ -597,7 +533,7 @@ class BaseDataset(torch.utils.data.Dataset):
597
  return img, face_cx, face_cy, face_w, face_h
598
 
599
  # いい感じに切り出す
600
- def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
601
  height, width = image.shape[0:2]
602
  if height == self.height and width == self.width:
603
  return image
@@ -605,8 +541,8 @@ class BaseDataset(torch.utils.data.Dataset):
605
  # 画像サイズはsizeより大きいのでリサイズする
606
  face_size = max(face_w, face_h)
607
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
608
- min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
609
- max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
610
  if min_scale >= max_scale: # range指定がmin==max
611
  scale = min_scale
612
  else:
@@ -624,13 +560,13 @@ class BaseDataset(torch.utils.data.Dataset):
624
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
625
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
626
 
627
- if subset.random_crop:
628
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
629
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
630
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
631
  else:
632
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
633
- if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
634
  if face_size > self.size // 10 and face_size >= 40:
635
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
636
 
@@ -653,6 +589,9 @@ class BaseDataset(torch.utils.data.Dataset):
653
  return self._length
654
 
655
  def __getitem__(self, index):
 
 
 
656
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
657
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
658
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -665,29 +604,28 @@ class BaseDataset(torch.utils.data.Dataset):
665
 
666
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
667
  image_info = self.image_data[image_key]
668
- subset = self.image_to_subset[image_key]
669
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
670
 
671
  # image/latentsを処理する
672
  if image_info.latents is not None:
673
- latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
674
  image = None
675
  elif image_info.latents_npz is not None:
676
- latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
677
  latents = torch.FloatTensor(latents)
678
  image = None
679
  else:
680
  # 画像を読み込み、必要ならcropする
681
- img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
682
  im_h, im_w = img.shape[0:2]
683
 
684
  if self.enable_bucket:
685
- img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
686
  else:
687
  if face_cx > 0: # 顔位置情報あり
688
- img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
689
  elif im_h > self.height or im_w > self.width:
690
- assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
691
  if im_h > self.height:
692
  p = random.randint(0, im_h - self.height)
693
  img = img[p:p + self.height]
@@ -699,9 +637,8 @@ class BaseDataset(torch.utils.data.Dataset):
699
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
700
 
701
  # augmentation
702
- aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
703
- if aug is not None:
704
- img = aug(image=img)['image']
705
 
706
  latents = None
707
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -709,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
709
  images.append(image)
710
  latents_list.append(latents)
711
 
712
- caption = self.process_caption(subset, image_info.caption)
713
  captions.append(caption)
714
  if not self.token_padding_disabled: # this option might be omitted in future
715
  input_ids_list.append(self.get_input_ids(caption))
@@ -740,8 +677,9 @@ class BaseDataset(torch.utils.data.Dataset):
740
 
741
 
742
  class DreamBoothDataset(BaseDataset):
743
- def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
744
- super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
745
 
746
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
747
 
@@ -764,7 +702,7 @@ class DreamBoothDataset(BaseDataset):
764
  self.bucket_reso_steps = None # この情報は使われない
765
  self.bucket_no_upscale = False
766
 
767
- def read_caption(img_path, caption_extension):
768
  # captionの候補ファイル名を作る
769
  base_name = os.path.splitext(img_path)[0]
770
  base_name_face_det = base_name
@@ -787,171 +725,153 @@ class DreamBoothDataset(BaseDataset):
787
  break
788
  return caption
789
 
790
- def load_dreambooth_dir(subset: DreamBoothSubset):
791
- if not os.path.isdir(subset.image_dir):
792
- print(f"not directory: {subset.image_dir}")
793
- return [], []
 
 
 
 
 
 
 
794
 
795
- img_paths = glob_images(subset.image_dir, "*")
796
- print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
797
 
798
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
799
  captions = []
800
  for img_path in img_paths:
801
- cap_for_img = read_caption(img_path, subset.caption_extension)
802
- if cap_for_img is None and subset.class_tokens is None:
803
- print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
804
- captions.append("")
805
- else:
806
- captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
807
-
808
- self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
809
 
810
- return img_paths, captions
811
-
812
- print("prepare images.")
813
- num_train_images = 0
814
- num_reg_images = 0
815
- reg_infos: List[ImageInfo] = []
816
- for subset in subsets:
817
- if subset.num_repeats < 1:
818
- print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
819
- continue
820
-
821
- if subset in self.subsets:
822
- print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
823
- continue
824
 
825
- img_paths, captions = load_dreambooth_dir(subset)
826
- if len(img_paths) < 1:
827
- print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
828
- continue
829
 
830
- if subset.is_reg:
831
- num_reg_images += subset.num_repeats * len(img_paths)
832
- else:
833
- num_train_images += subset.num_repeats * len(img_paths)
 
 
834
 
835
  for img_path, caption in zip(img_paths, captions):
836
- info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
837
- if subset.is_reg:
838
- reg_infos.append(info)
839
- else:
840
- self.register_image(info, subset)
841
 
842
- subset.img_count = len(img_paths)
843
- self.subsets.append(subset)
844
 
845
  print(f"{num_train_images} train images with repeating.")
846
  self.num_train_images = num_train_images
847
 
848
- print(f"{num_reg_images} reg images.")
849
- if num_train_images < num_reg_images:
850
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
851
-
852
- if num_reg_images == 0:
853
- print("no regularization images / 正則化画像が見つかりませんでした")
854
- else:
855
- # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
856
- n = 0
857
- first_loop = True
858
- while n < num_train_images:
859
- for info in reg_infos:
860
- if first_loop:
861
- self.register_image(info, subset)
862
- n += info.num_repeats
863
- else:
864
- info.num_repeats += 1
865
- n += 1
866
- if n >= num_train_images:
867
- break
868
- first_loop = False
869
-
870
- self.num_reg_images = num_reg_images
871
-
872
-
873
- class FineTuningDataset(BaseDataset):
874
- def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
875
- super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
876
 
877
- self.batch_size = batch_size
 
 
 
878
 
879
- self.num_train_images = 0
880
- self.num_reg_images = 0
 
881
 
882
- for subset in subsets:
883
- if subset.num_repeats < 1:
884
- print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
885
- continue
886
 
887
- if subset in self.subsets:
888
- print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
889
- continue
890
 
891
- # メタデータを読み込む
892
- if os.path.exists(subset.metadata_file):
893
- print(f"loading existing metadata: {subset.metadata_file}")
894
- with open(subset.metadata_file, "rt", encoding='utf-8') as f:
895
- metadata = json.load(f)
896
  else:
897
- raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
 
 
 
 
 
898
 
899
- if len(metadata) < 1:
900
- print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
901
- continue
902
-
903
- tags_list = []
904
- for image_key, img_md in metadata.items():
905
- # path情報を作る
906
- if os.path.exists(image_key):
907
- abs_path = image_key
908
- else:
909
- # わりといい加減だがいい方法が思いつかん
910
- abs_path = glob_images(subset.image_dir, image_key)
911
- assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
912
- abs_path = abs_path[0]
913
-
914
- caption = img_md.get('caption')
915
- tags = img_md.get('tags')
916
- if caption is None:
917
- caption = tags
918
- elif tags is not None and len(tags) > 0:
919
- caption = caption + ', ' + tags
920
- tags_list.append(tags)
921
- assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
922
 
923
- image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
924
- image_info.image_size = img_md.get('train_resolution')
925
 
926
- if not subset.color_aug and not subset.random_crop:
927
- # if npz exists, use them
928
- image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
 
 
 
 
929
 
930
- self.register_image(image_info, subset)
 
 
931
 
932
- self.num_train_images += len(metadata) * subset.num_repeats
 
 
 
 
 
 
 
933
 
934
- # TODO do not record tag freq when no tag
935
- self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
936
- subset.img_count = len(metadata)
937
- self.subsets.append(subset)
938
 
939
  # check existence of all npz files
940
- use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
941
  if use_npz_latents:
942
- flip_aug_in_subset = False
943
  npz_any = False
944
  npz_all = True
945
-
946
  for image_info in self.image_data.values():
947
- subset = self.image_to_subset[image_info.image_key]
948
-
949
  has_npz = image_info.latents_npz is not None
950
  npz_any = npz_any or has_npz
951
 
952
- if subset.flip_aug:
953
  has_npz = has_npz and image_info.latents_npz_flipped is not None
954
- flip_aug_in_subset = True
955
  npz_all = npz_all and has_npz
956
 
957
  if npz_any and not npz_all:
@@ -963,7 +883,7 @@ class FineTuningDataset(BaseDataset):
963
  elif not npz_all:
964
  use_npz_latents = False
965
  print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
966
- if flip_aug_in_subset:
967
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
968
  # else:
969
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -1009,7 +929,7 @@ class FineTuningDataset(BaseDataset):
1009
  for image_info in self.image_data.values():
1010
  image_info.latents_npz = image_info.latents_npz_flipped = None
1011
 
1012
- def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
1013
  base_name = os.path.splitext(image_key)[0]
1014
  npz_file_norm = base_name + '.npz'
1015
 
@@ -1021,8 +941,8 @@ class FineTuningDataset(BaseDataset):
1021
  return npz_file_norm, npz_file_flip
1022
 
1023
  # image_key is relative path
1024
- npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
1025
- npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
1026
 
1027
  if not os.path.exists(npz_file_norm):
1028
  npz_file_norm = None
@@ -1033,60 +953,13 @@ class FineTuningDataset(BaseDataset):
1033
  return npz_file_norm, npz_file_flip
1034
 
1035
 
1036
- # behave as Dataset mock
1037
- class DatasetGroup(torch.utils.data.ConcatDataset):
1038
- def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
1039
- self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
1040
-
1041
- super().__init__(datasets)
1042
-
1043
- self.image_data = {}
1044
- self.num_train_images = 0
1045
- self.num_reg_images = 0
1046
-
1047
- # simply concat together
1048
- # TODO: handling image_data key duplication among dataset
1049
- # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
1050
- for dataset in datasets:
1051
- self.image_data.update(dataset.image_data)
1052
- self.num_train_images += dataset.num_train_images
1053
- self.num_reg_images += dataset.num_reg_images
1054
-
1055
- def add_replacement(self, str_from, str_to):
1056
- for dataset in self.datasets:
1057
- dataset.add_replacement(str_from, str_to)
1058
-
1059
- # def make_buckets(self):
1060
- # for dataset in self.datasets:
1061
- # dataset.make_buckets()
1062
-
1063
- def cache_latents(self, vae):
1064
- for i, dataset in enumerate(self.datasets):
1065
- print(f"[Dataset {i}]")
1066
- dataset.cache_latents(vae)
1067
-
1068
- def is_latent_cacheable(self) -> bool:
1069
- return all([dataset.is_latent_cacheable() for dataset in self.datasets])
1070
-
1071
- def set_current_epoch(self, epoch):
1072
- for dataset in self.datasets:
1073
- dataset.set_current_epoch(epoch)
1074
-
1075
- def disable_token_padding(self):
1076
- for dataset in self.datasets:
1077
- dataset.disable_token_padding()
1078
-
1079
-
1080
  def debug_dataset(train_dataset, show_input_ids=False):
1081
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
1082
  print("Escape for exit. / Escキーで中断、終了します")
1083
 
1084
  train_dataset.set_current_epoch(1)
1085
  k = 0
1086
- indices = list(range(len(train_dataset)))
1087
- random.shuffle(indices)
1088
- for i, idx in enumerate(indices):
1089
- example = train_dataset[idx]
1090
  if example['latents'] is not None:
1091
  print(f"sample has latents from npz file: {example['latents'].size()}")
1092
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
@@ -1491,35 +1364,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1491
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1492
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1493
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1494
- parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
1495
- help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
1496
-
1497
-
1498
- def add_optimizer_arguments(parser: argparse.ArgumentParser):
1499
- parser.add_argument("--optimizer_type", type=str, default="",
1500
- help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
1501
-
1502
- # backward compatibility
1503
- parser.add_argument("--use_8bit_adam", action="store_true",
1504
- help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1505
- parser.add_argument("--use_lion_optimizer", action="store_true",
1506
- help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1507
-
1508
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1509
- parser.add_argument("--max_grad_norm", default=1.0, type=float,
1510
- help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
1511
-
1512
- parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
1513
- help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
1514
-
1515
- parser.add_argument("--lr_scheduler", type=str, default="constant",
1516
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
1517
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
1518
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1519
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
1520
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
1521
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
1522
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
1523
 
1524
 
1525
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1543,6 +1387,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1543
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1544
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1545
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
 
 
 
 
1546
  parser.add_argument("--mem_eff_attn", action="store_true",
1547
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1548
  parser.add_argument("--xformers", action="store_true",
@@ -1550,6 +1398,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1550
  parser.add_argument("--vae", type=str, default=None,
1551
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1552
 
 
1553
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1554
  parser.add_argument("--max_train_epochs", type=int, default=None,
1555
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1570,23 +1419,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1570
  parser.add_argument("--logging_dir", type=str, default=None,
1571
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1572
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
 
 
 
 
1573
  parser.add_argument("--noise_offset", type=float, default=None,
1574
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1575
  parser.add_argument("--lowram", action="store_true",
1576
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1577
 
1578
- parser.add_argument("--sample_every_n_steps", type=int, default=None,
1579
- help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
1580
- parser.add_argument("--sample_every_n_epochs", type=int, default=None,
1581
- help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
1582
- parser.add_argument("--sample_prompts", type=str, default=None,
1583
- help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
1584
- parser.add_argument('--sample_sampler', type=str, default='ddim',
1585
- choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
1586
- 'dpmsolver++', 'dpmsingle',
1587
- 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
1588
- help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
1589
-
1590
  if support_dreambooth:
1591
  # DreamBooth training
1592
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1608,8 +1449,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1608
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1609
  parser.add_argument("--caption_extention", type=str, default=None,
1610
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1611
- parser.add_argument("--keep_tokens", type=int, default=0,
1612
- help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
1613
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1614
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1615
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1634,11 +1475,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1634
  if support_caption_dropout:
1635
  # Textual Inversion はcaptionのdropoutをsupportしない
1636
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1637
- parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
1638
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1639
- parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
1640
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1641
- parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
1642
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1643
 
1644
  if support_dreambooth:
@@ -1663,249 +1504,16 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1663
  # region utils
1664
 
1665
 
1666
- def get_optimizer(args, trainable_params):
1667
- # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
1668
-
1669
- optimizer_type = args.optimizer_type
1670
- if args.use_8bit_adam:
1671
- assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
1672
- assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
1673
- optimizer_type = "AdamW8bit"
1674
-
1675
- elif args.use_lion_optimizer:
1676
- assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
1677
- optimizer_type = "Lion"
1678
-
1679
- if optimizer_type is None or optimizer_type == "":
1680
- optimizer_type = "AdamW"
1681
- optimizer_type = optimizer_type.lower()
1682
-
1683
- # 引数を分解する:boolとfloat、tupleのみ対応
1684
- optimizer_kwargs = {}
1685
- if args.optimizer_args is not None and len(args.optimizer_args) > 0:
1686
- for arg in args.optimizer_args:
1687
- key, value = arg.split('=')
1688
-
1689
- value = value.split(",")
1690
- for i in range(len(value)):
1691
- if value[i].lower() == "true" or value[i].lower() == "false":
1692
- value[i] = (value[i].lower() == "true")
1693
- else:
1694
- value[i] = float(value[i])
1695
- if len(value) == 1:
1696
- value = value[0]
1697
- else:
1698
- value = tuple(value)
1699
-
1700
- optimizer_kwargs[key] = value
1701
- # print("optkwargs:", optimizer_kwargs)
1702
-
1703
- lr = args.learning_rate
1704
-
1705
- if optimizer_type == "AdamW8bit".lower():
1706
- try:
1707
- import bitsandbytes as bnb
1708
- except ImportError:
1709
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1710
- print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
1711
- optimizer_class = bnb.optim.AdamW8bit
1712
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1713
-
1714
- elif optimizer_type == "SGDNesterov8bit".lower():
1715
- try:
1716
- import bitsandbytes as bnb
1717
- except ImportError:
1718
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1719
- print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
1720
- if "momentum" not in optimizer_kwargs:
1721
- print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1722
- optimizer_kwargs["momentum"] = 0.9
1723
-
1724
- optimizer_class = bnb.optim.SGD8bit
1725
- optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1726
-
1727
- elif optimizer_type == "Lion".lower():
1728
- try:
1729
- import lion_pytorch
1730
- except ImportError:
1731
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
1732
- print(f"use Lion optimizer | {optimizer_kwargs}")
1733
- optimizer_class = lion_pytorch.Lion
1734
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1735
-
1736
- elif optimizer_type == "SGDNesterov".lower():
1737
- print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
1738
- if "momentum" not in optimizer_kwargs:
1739
- print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1740
- optimizer_kwargs["momentum"] = 0.9
1741
-
1742
- optimizer_class = torch.optim.SGD
1743
- optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1744
-
1745
- elif optimizer_type == "DAdaptation".lower():
1746
- try:
1747
- import dadaptation
1748
- except ImportError:
1749
- raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
1750
- print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
1751
-
1752
- min_lr = lr
1753
- if type(trainable_params) == list and type(trainable_params[0]) == dict:
1754
- for group in trainable_params:
1755
- min_lr = min(min_lr, group.get("lr", lr))
1756
-
1757
- if min_lr <= 0.1:
1758
- print(
1759
- f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
1760
- print('recommend option: lr=1.0 / 推奨は1.0です')
1761
-
1762
- optimizer_class = dadaptation.DAdaptAdam
1763
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1764
-
1765
- elif optimizer_type == "Adafactor".lower():
1766
- # 引数を確認して適宜補正する
1767
- if "relative_step" not in optimizer_kwargs:
1768
- optimizer_kwargs["relative_step"] = True # default
1769
- if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
1770
- print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
1771
- optimizer_kwargs["relative_step"] = True
1772
- print(f"use Adafactor optimizer | {optimizer_kwargs}")
1773
-
1774
- if optimizer_kwargs["relative_step"]:
1775
- print(f"relative_step is true / relative_stepがtrueです")
1776
- if lr != 0.0:
1777
- print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
1778
- args.learning_rate = None
1779
-
1780
- # trainable_paramsがgroupだった時の処理:lrを削除する
1781
- if type(trainable_params) == list and type(trainable_params[0]) == dict:
1782
- has_group_lr = False
1783
- for group in trainable_params:
1784
- p = group.pop("lr", None)
1785
- has_group_lr = has_group_lr or (p is not None)
1786
-
1787
- if has_group_lr:
1788
- # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
1789
- print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
1790
- args.unet_lr = None
1791
- args.text_encoder_lr = None
1792
-
1793
- if args.lr_scheduler != "adafactor":
1794
- print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
1795
- args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
1796
-
1797
- lr = None
1798
- else:
1799
- if args.max_grad_norm != 0.0:
1800
- print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
1801
- if args.lr_scheduler != "constant_with_warmup":
1802
- print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
1803
- if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
1804
- print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
1805
-
1806
- optimizer_class = transformers.optimization.Adafactor
1807
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1808
-
1809
- elif optimizer_type == "AdamW".lower():
1810
- print(f"use AdamW optimizer | {optimizer_kwargs}")
1811
- optimizer_class = torch.optim.AdamW
1812
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1813
-
1814
- else:
1815
- # 任意のoptimizerを使う
1816
- optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
1817
- print(f"use {optimizer_type} | {optimizer_kwargs}")
1818
- if "." not in optimizer_type:
1819
- optimizer_module = torch.optim
1820
- else:
1821
- values = optimizer_type.split(".")
1822
- optimizer_module = importlib.import_module(".".join(values[:-1]))
1823
- optimizer_type = values[-1]
1824
-
1825
- optimizer_class = getattr(optimizer_module, optimizer_type)
1826
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1827
-
1828
- optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
1829
- optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
1830
-
1831
- return optimizer_name, optimizer_args, optimizer
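For reference, the parsing above only accepts bool, float and tuple values for --optimizer_args. A minimal standalone sketch of the same logic (the helper name parse_optimizer_args is ours, not part of the commit):

def parse_optimizer_args(optimizer_args):
    # each "key=value" entry becomes a kwarg; "true"/"false" become bools,
    # everything else a float; comma-separated values become tuples
    kwargs = {}
    for arg in optimizer_args:
        key, value = arg.split('=')
        parts = [v.lower() == 'true' if v.lower() in ('true', 'false') else float(v)
                 for v in value.split(',')]
        kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
    return kwargs

# parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999"])
# -> {'weight_decay': 0.01, 'betas': (0.9, 0.999)}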
1832
-
1833
-
1834
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
1835
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
1836
- # Which is a newer release of diffusers than currently packaged with sd-scripts
1837
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
1838
-
1839
-
1840
- def get_scheduler_fix(
1841
- name: Union[str, SchedulerType],
1842
- optimizer: Optimizer,
1843
- num_warmup_steps: Optional[int] = None,
1844
- num_training_steps: Optional[int] = None,
1845
- num_cycles: int = 1,
1846
- power: float = 1.0,
1847
- ):
1848
- """
1849
- Unified API to get any scheduler from its name.
1850
- Args:
1851
- name (`str` or `SchedulerType`):
1852
- The name of the scheduler to use.
1853
- optimizer (`torch.optim.Optimizer`):
1854
- The optimizer that will be used during training.
1855
- num_warmup_steps (`int`, *optional*):
1856
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
1857
- optional), the function will raise an error if it's unset and the scheduler type requires it.
1858
- num_training_steps (`int``, *optional*):
1859
- The number of training steps to do. This is not required by all schedulers (hence the argument being
1860
- optional), the function will raise an error if it's unset and the scheduler type requires it.
1861
- num_cycles (`int`, *optional*):
1862
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
1863
- power (`float`, *optional*, defaults to 1.0):
1864
- Power factor. See `POLYNOMIAL` scheduler
1865
- last_epoch (`int`, *optional*, defaults to -1):
1866
- The index of the last epoch when resuming training.
1867
- """
1868
- if name.startswith("adafactor"):
1869
- assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
1870
- initial_lr = float(name.split(':')[1])
1871
- # print("adafactor scheduler init lr", initial_lr)
1872
- return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
1873
-
1874
- name = SchedulerType(name)
1875
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
1876
- if name == SchedulerType.CONSTANT:
1877
- return schedule_func(optimizer)
1878
-
1879
- # All other schedulers require `num_warmup_steps`
1880
- if num_warmup_steps is None:
1881
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
1882
-
1883
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
1884
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
1885
-
1886
- # All other schedulers require `num_training_steps`
1887
- if num_training_steps is None:
1888
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1889
-
1890
- if name == SchedulerType.COSINE_WITH_RESTARTS:
1891
- return schedule_func(
1892
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
1893
- )
1894
-
1895
- if name == SchedulerType.POLYNOMIAL:
1896
- return schedule_func(
1897
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
1898
- )
1899
-
1900
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
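For reference, the "adafactor" branch above relies on get_optimizer() packing the initial learning rate into the scheduler name as "adafactor:<lr>". A minimal illustration (the value 1.0 is an example only):

name = "adafactor:1.0"                      # set by get_optimizer() when relative_step is used
if name.startswith("adafactor"):
    initial_lr = float(name.split(':')[1])  # 1.0, handed to transformers AdafactorSchedule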
1901
-
1902
-
1903
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1904
  # backward compatibility
1905
  if args.caption_extention is not None:
1906
  args.caption_extension = args.caption_extention
1907
  args.caption_extention = None
1908
 
 
 
 
 
1909
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1910
  if args.resolution is not None:
1911
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
@@ -1928,28 +1536,12 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1928
 
1929
  def load_tokenizer(args: argparse.Namespace):
1930
  print("prepare tokenizer")
1931
- original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
1932
-
1933
- tokenizer: CLIPTokenizer = None
1934
- if args.tokenizer_cache_dir:
1935
- local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
1936
- if os.path.exists(local_tokenizer_path):
1937
- print(f"load tokenizer from cache: {local_tokenizer_path}")
1938
- tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
1939
-
1940
- if tokenizer is None:
1941
- if args.v2:
1942
- tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
1943
- else:
1944
- tokenizer = CLIPTokenizer.from_pretrained(original_path)
1945
-
1946
- if hasattr(args, "max_token_length") and args.max_token_length is not None:
1947
  print(f"update token length: {args.max_token_length}")
1948
-
1949
- if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
1950
- print(f"save Tokenizer to cache: {local_tokenizer_path}")
1951
- tokenizer.save_pretrained(local_tokenizer_path)
1952
-
1953
  return tokenizer
1954
 
1955
 
@@ -2000,19 +1592,13 @@ def prepare_dtype(args: argparse.Namespace):
2000
 
2001
 
2002
  def load_target_model(args: argparse.Namespace, weight_dtype):
2003
- name_or_path = args.pretrained_model_name_or_path
2004
- name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
2005
- load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
2006
  if load_stable_diffusion_format:
2007
  print("load StableDiffusion checkpoint")
2008
- text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
2009
  else:
2010
  print("load Diffusers pretrained models")
2011
- try:
2012
- pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
2013
- except EnvironmentError as ex:
2014
- print(
2015
- f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
2016
  text_encoder = pipe.text_encoder
2017
  vae = pipe.vae
2018
  unet = pipe.unet
@@ -2181,185 +1767,6 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
2181
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
2182
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
2183
 
2184
-
2185
- # scheduler:
2186
- SCHEDULER_LINEAR_START = 0.00085
2187
- SCHEDULER_LINEAR_END = 0.0120
2188
- SCHEDULER_TIMESTEPS = 1000
2189
- SCHEDLER_SCHEDULE = 'scaled_linear'
2190
-
2191
-
2192
- def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
2193
- """
2194
- 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
2195
- clip skipは対応した
2196
- """
2197
- if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
2198
- return
2199
- if args.sample_every_n_epochs is not None:
2200
- # sample_every_n_steps は無視する
2201
- if epoch is None or epoch % args.sample_every_n_epochs != 0:
2202
- return
2203
- else:
2204
- if steps % args.sample_every_n_steps != 0:
2205
- return
2206
-
2207
- print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
2208
- if not os.path.isfile(args.sample_prompts):
2209
- print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
2210
- return
2211
-
2212
- # ここでCUDAのキャッシュクリアとかしたほうがいいのか……
2213
-
2214
- org_vae_device = vae.device # CPUにいるはず
2215
- vae.to(device)
2216
-
2217
- # clip skip 対応のための wrapper を作る
2218
- if args.clip_skip is None:
2219
- text_encoder_or_wrapper = text_encoder
2220
- else:
2221
- class Wrapper():
2222
- def __init__(self, tenc) -> None:
2223
- self.tenc = tenc
2224
- self.config = {}
2225
- super().__init__()
2226
-
2227
- def __call__(self, input_ids, attention_mask):
2228
- enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
2229
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
2230
- encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
2231
- pooled_output = enc_out['pooler_output']
2232
- return encoder_hidden_states, pooled_output # 1st output is only used
2233
-
2234
- text_encoder_or_wrapper = Wrapper(text_encoder)
2235
-
2236
- # read prompts
2237
- with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
2238
- prompts = f.readlines()
2239
-
2240
- # schedulerを用意する
2241
- sched_init_args = {}
2242
- if args.sample_sampler == "ddim":
2243
- scheduler_cls = DDIMScheduler
2244
- elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
2245
- scheduler_cls = DDPMScheduler
2246
- elif args.sample_sampler == "pndm":
2247
- scheduler_cls = PNDMScheduler
2248
- elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
2249
- scheduler_cls = LMSDiscreteScheduler
2250
- elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
2251
- scheduler_cls = EulerDiscreteScheduler
2252
- elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
2253
- scheduler_cls = EulerAncestralDiscreteScheduler
2254
- elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
2255
- scheduler_cls = DPMSolverMultistepScheduler
2256
- sched_init_args['algorithm_type'] = args.sample_sampler
2257
- elif args.sample_sampler == "dpmsingle":
2258
- scheduler_cls = DPMSolverSinglestepScheduler
2259
- elif args.sample_sampler == "heun":
2260
- scheduler_cls = HeunDiscreteScheduler
2261
- elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
2262
- scheduler_cls = KDPM2DiscreteScheduler
2263
- elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
2264
- scheduler_cls = KDPM2AncestralDiscreteScheduler
2265
- else:
2266
- scheduler_cls = DDIMScheduler
2267
-
2268
- if args.v_parameterization:
2269
- sched_init_args['prediction_type'] = 'v_prediction'
2270
-
2271
- scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
2272
- beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
2273
- beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
2274
-
2275
- # clip_sample=Trueにする
2276
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
2277
- # print("set clip_sample to True")
2278
- scheduler.config.clip_sample = True
2279
-
2280
- pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
2281
- scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
2282
- pipeline.to(device)
2283
-
2284
- save_dir = args.output_dir + "/sample"
2285
- os.makedirs(save_dir, exist_ok=True)
2286
-
2287
- rng_state = torch.get_rng_state()
2288
- cuda_rng_state = torch.cuda.get_rng_state()
2289
-
2290
- with torch.no_grad():
2291
- with accelerator.autocast():
2292
- for i, prompt in enumerate(prompts):
2293
- prompt = prompt.strip()
2294
- if len(prompt) == 0 or prompt[0] == '#':
2295
- continue
2296
-
2297
- # subset of gen_img_diffusers
2298
- prompt_args = prompt.split(' --')
2299
- prompt = prompt_args[0]
2300
- negative_prompt = None
2301
- sample_steps = 30
2302
- width = height = 512
2303
- scale = 7.5
2304
- seed = None
2305
- for parg in prompt_args:
2306
- try:
2307
- m = re.match(r'w (\d+)', parg, re.IGNORECASE)
2308
- if m:
2309
- width = int(m.group(1))
2310
- continue
2311
-
2312
- m = re.match(r'h (\d+)', parg, re.IGNORECASE)
2313
- if m:
2314
- height = int(m.group(1))
2315
- continue
2316
-
2317
- m = re.match(r'd (\d+)', parg, re.IGNORECASE)
2318
- if m:
2319
- seed = int(m.group(1))
2320
- continue
2321
-
2322
- m = re.match(r's (\d+)', parg, re.IGNORECASE)
2323
- if m: # steps
2324
- sample_steps = max(1, min(1000, int(m.group(1))))
2325
- continue
2326
-
2327
- m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
2328
- if m: # scale
2329
- scale = float(m.group(1))
2330
- continue
2331
-
2332
- m = re.match(r'n (.+)', parg, re.IGNORECASE)
2333
- if m: # negative prompt
2334
- negative_prompt = m.group(1)
2335
- continue
2336
-
2337
- except ValueError as ex:
2338
- print(f"Exception in parsing / 解析エラー: {parg}")
2339
- print(ex)
2340
-
2341
- if seed is not None:
2342
- torch.manual_seed(seed)
2343
- torch.cuda.manual_seed(seed)
2344
-
2345
- if prompt_replacement is not None:
2346
- prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
2347
- if negative_prompt is not None:
2348
- negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
2349
-
2350
- image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
2351
-
2352
- ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
2353
- num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
2354
- seed_suffix = "" if seed is None else f"_{seed}"
2355
- img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
2356
-
2357
- image.save(os.path.join(save_dir, img_filename))
2358
-
2359
- torch.set_rng_state(rng_state)
2360
- torch.cuda.set_rng_state(cuda_rng_state)
2361
- vae.to(org_vae_device)
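For reference, one possible line of the --sample_prompts file as parsed above (the prompt and values are placeholders; lines starting with # are skipped, and options after " --" set width w, height h, seed d, steps s, CFG scale l and negative prompt n):

masterpiece, 1girl, best quality --n lowres, bad anatomy --w 512 --h 768 --d 1 --s 28 --l 7.5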
2362
-
2363
  # endregion
2364
 
2365
  # region 前処理用
 
1
  # common functions for training
2
 
3
  import argparse
 
4
  import json
 
5
  import shutil
6
  import time
7
+ from typing import Dict, List, NamedTuple, Tuple
 
 
 
 
 
 
 
 
8
  from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
  import glob
11
  import math
12
  import os
 
17
 
18
  from tqdm import tqdm
19
  import torch
 
20
  from torchvision import transforms
21
  from transformers import CLIPTokenizer
 
22
  import diffusers
23
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
 
 
 
 
24
  import albumentations as albu
25
  import numpy as np
26
  from PIL import Image
 
195
  batch_index: int
196
 
197
 

 
 
198
  class BaseDataset(torch.utils.data.Dataset):
199
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
  super().__init__()
201
+ self.tokenizer: CLIPTokenizer = tokenizer
202
  self.max_token_length = max_token_length
203
+ self.shuffle_caption = shuffle_caption
204
+ self.shuffle_keep_tokens = shuffle_keep_tokens
205
  # width/height is used when enable_bucket==False
206
  self.width, self.height = (None, None) if resolution is None else resolution
207
+ self.face_crop_aug_range = face_crop_aug_range
208
+ self.flip_aug = flip_aug
209
+ self.color_aug = color_aug
210
  self.debug_dataset = debug_dataset
211
+ self.random_crop = random_crop
 
 
212
  self.token_padding_disabled = False
213
+ self.dataset_dirs_info = {}
214
+ self.reg_dataset_dirs_info = {}
215
  self.tag_frequency = {}
216
 
217
  self.enable_bucket = False
 
225
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
 
227
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
228
+ self.dropout_rate: float = 0
229
+ self.dropout_every_n_epochs: int = None
230
+ self.tag_dropout_rate: float = 0
231
 
232
  # augmentation
233
+ flip_p = 0.5 if flip_aug else 0.0
234
+ if color_aug:
235
+ # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
236
+ self.aug = albu.Compose([
237
+ albu.OneOf([
238
+ albu.HueSaturationValue(8, 0, 0, p=.5),
239
+ albu.RandomGamma((95, 105), p=.5),
240
+ ], p=.33),
241
+ albu.HorizontalFlip(p=flip_p)
242
+ ], p=1.)
243
+ elif flip_aug:
244
+ self.aug = albu.Compose([
245
+ albu.HorizontalFlip(p=flip_p)
246
+ ], p=1.)
247
+ else:
248
+ self.aug = None
249
 
250
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
 
252
  self.image_data: Dict[str, ImageInfo] = {}
 
253
 
254
  self.replacements = {}
255
 
256
  def set_current_epoch(self, epoch):
257
  self.current_epoch = epoch
258
+
259
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
+ # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
261
+ self.dropout_rate = dropout_rate
262
+ self.dropout_every_n_epochs = dropout_every_n_epochs
263
+ self.tag_dropout_rate = tag_dropout_rate
264
 
265
  def set_tag_frequency(self, dir_name, captions):
266
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
  self.tag_frequency[dir_name] = frequency_for_dir
268
  for caption in captions:
269
  for tag in caption.split(","):
270
+ if tag and not tag.isspace():
 
271
  tag = tag.lower()
272
  frequency = frequency_for_dir.get(tag, 0)
273
  frequency_for_dir[tag] = frequency + 1
 
278
  def add_replacement(self, str_from, str_to):
279
  self.replacements[str_from] = str_to
280
 
281
+ def process_caption(self, caption):
282
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
283
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
 
286
  if is_drop_out:
287
  caption = ""
288
  else:
289
+ if self.shuffle_caption or self.tag_dropout_rate > 0:
290
  def dropout_tags(tokens):
291
+ if self.tag_dropout_rate <= 0:
292
  return tokens
293
  l = []
294
  for token in tokens:
295
+ if random.random() >= self.tag_dropout_rate:
296
  l.append(token)
297
  return l
298
 
299
+ tokens = [t.strip() for t in caption.strip().split(",")]
300
+ if self.shuffle_keep_tokens is None:
301
+ if self.shuffle_caption:
302
+ random.shuffle(tokens)
303
+
304
+ tokens = dropout_tags(tokens)
305
+ else:
306
+ if len(tokens) > self.shuffle_keep_tokens:
307
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
308
+ tokens = tokens[self.shuffle_keep_tokens:]
309
 
310
+ if self.shuffle_caption:
311
+ random.shuffle(tokens)
312
 
313
+ tokens = dropout_tags(tokens)
314
 
315
+ tokens = keep_tokens + tokens
316
+ caption = ", ".join(tokens)
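A worked example of the keep-tokens path above, with made-up tags: given shuffle_keep_tokens=2 and the caption "1girl, solo, red hair, smile", the first two tags stay fixed while only the rest are shuffled and exposed to tag dropout.

# tokens = ["1girl", "solo", "red hair", "smile"], shuffle_keep_tokens = 2
# keep_tokens = ["1girl", "solo"]; only ["red hair", "smile"] are shuffled/dropped
# caption -> "1girl, solo, smile, red hair"   (one possible outcome)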
317
 
318
  # textual inversion対応
319
  for str_from, str_to in self.replacements.items():
 
367
  input_ids = torch.stack(iids_list) # 3,77
368
  return input_ids
369
 
370
+ def register_image(self, info: ImageInfo):
371
  self.image_data[info.image_key] = info
 
372
 
373
  def make_buckets(self):
374
  '''
 
467
  img = np.array(image, np.uint8)
468
  return img
469
 
470
+ def trim_and_resize_if_required(self, image, reso, resized_size):
471
  image_height, image_width = image.shape[0:2]
472
 
473
  if image_width != resized_size[0] or image_height != resized_size[1]:
 
477
  image_height, image_width = image.shape[0:2]
478
  if image_width > reso[0]:
479
  trim_size = image_width - reso[0]
480
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
  # print("w", trim_size, p)
482
  image = image[:, p:p + reso[0]]
483
  if image_height > reso[1]:
484
  trim_size = image_height - reso[1]
485
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
  # print("h", trim_size, p)
487
  image = image[p:p + reso[1]]
488
 
489
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
  return image
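A worked example of the cropping above, with made-up sizes: a 640x512 image resized for a 512x512 bucket has trim_size = 128 on the width; with random_crop off the offset is trim_size // 2 = 64 (a center crop), with random_crop on it is drawn uniformly from [0, 128].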
491
 
 
 
 
492
  def cache_latents(self, vae):
493
  # TODO ここを高速化したい
494
  print("caching latents.")
495
  for info in tqdm(self.image_data.values()):
 
 
496
  if info.latents_npz is not None:
497
  info.latents = self.load_latents_from_npz(info, False)
498
  info.latents = torch.FloatTensor(info.latents)
 
502
  continue
503
 
504
  image = self.load_image(info.absolute_path)
505
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
 
507
  img_tensor = self.image_transforms(image)
508
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
 
511
+ if self.flip_aug:
512
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
  img_tensor = self.image_transforms(image)
514
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
 
518
  image = Image.open(image_path)
519
  return image.size
520
 
521
+ def load_image_with_face_info(self, image_path: str):
522
  img = self.load_image(image_path)
523
 
524
  face_cx = face_cy = face_w = face_h = 0
525
+ if self.face_crop_aug_range is not None:
526
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
  if len(tokens) >= 5:
528
  face_cx = int(tokens[-4])
 
533
  return img, face_cx, face_cy, face_w, face_h
534
 
535
  # いい感じに切り出す
536
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
  height, width = image.shape[0:2]
538
  if height == self.height and width == self.width:
539
  return image
 
541
  # 画像サイズはsizeより大きいのでリサイズする
542
  face_size = max(face_w, face_h)
543
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
544
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
545
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
546
  if min_scale >= max_scale: # range指定がmin==max
547
  scale = min_scale
548
  else:
 
560
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
562
 
563
+ if self.random_crop:
564
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
565
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
566
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
567
  else:
568
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
569
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
  if face_size > self.size // 10 and face_size >= 40:
571
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
 
 
589
  return self._length
590
 
591
  def __getitem__(self, index):
592
+ if index == 0:
593
+ self.shuffle_buckets()
594
+
595
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
 
604
 
605
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
  image_info = self.image_data[image_key]
 
607
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
 
609
  # image/latentsを処理する
610
  if image_info.latents is not None:
611
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
  image = None
613
  elif image_info.latents_npz is not None:
614
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
  latents = torch.FloatTensor(latents)
616
  image = None
617
  else:
618
  # 画像を読み込み、必要ならcropする
619
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
  im_h, im_w = img.shape[0:2]
621
 
622
  if self.enable_bucket:
623
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
  else:
625
  if face_cx > 0: # 顔位置情報あり
626
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
  elif im_h > self.height or im_w > self.width:
628
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
  if im_h > self.height:
630
  p = random.randint(0, im_h - self.height)
631
  img = img[p:p + self.height]
 
637
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
 
639
  # augmentation
640
+ if self.aug is not None:
641
+ img = self.aug(image=img)['image']
 
642
 
643
  latents = None
644
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
 
646
  images.append(image)
647
  latents_list.append(latents)
648
 
649
+ caption = self.process_caption(image_info.caption)
650
  captions.append(caption)
651
  if not self.token_padding_disabled: # this option might be omitted in future
652
  input_ids_list.append(self.get_input_ids(caption))
 
677
 
678
 
679
  class DreamBoothDataset(BaseDataset):
680
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
 
684
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
 
 
702
  self.bucket_reso_steps = None # この情報は使われない
703
  self.bucket_no_upscale = False
704
 
705
+ def read_caption(img_path):
706
  # captionの候補ファイル名を作る
707
  base_name = os.path.splitext(img_path)[0]
708
  base_name_face_det = base_name
 
725
  break
726
  return caption
727
 
728
+ def load_dreambooth_dir(dir):
729
+ if not os.path.isdir(dir):
730
+ # print(f"ignore file: {dir}")
731
+ return 0, [], []
732
+
733
+ tokens = os.path.basename(dir).split('_')
734
+ try:
735
+ n_repeats = int(tokens[0])
736
+ except ValueError as e:
737
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
+ return 0, [], []
739
 
740
+ caption_by_folder = '_'.join(tokens[1:])
741
+ img_paths = glob_images(dir, "*")
742
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
 
744
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
745
  captions = []
746
  for img_path in img_paths:
747
+ cap_for_img = read_caption(img_path)
748
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
 
 
 
 
 
 
749
 
750
+ self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
 
 
 
 
 
 
 
 
 
 
 
 
 
751
 
752
+ return n_repeats, img_paths, captions
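For reference, load_dreambooth_dir() above expects subdirectory names of the form <repeats>_<caption>; the layout below is a placeholder. A folder such as train_data_dir/10_sks dog/ yields n_repeats = 10 and the folder caption "sks dog", and a per-image caption file (see --caption_extension, default ".caption"), when present, overrides the folder caption for that image.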
 
 
 
753
 
754
+ print("prepare train images.")
755
+ train_dirs = os.listdir(train_data_dir)
756
+ num_train_images = 0
757
+ for dir in train_dirs:
758
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
+ num_train_images += n_repeats * len(img_paths)
760
 
761
  for img_path, caption in zip(img_paths, captions):
762
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
+ self.register_image(info)
 
 
 
764
 
765
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
766
 
767
  print(f"{num_train_images} train images with repeating.")
768
  self.num_train_images = num_train_images
769
 
770
+ # reg imageは数を数えて学習画像と同じ枚数にする
771
+ num_reg_images = 0
772
+ if reg_data_dir:
773
+ print("prepare reg images.")
774
+ reg_infos: List[ImageInfo] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
776
+ reg_dirs = os.listdir(reg_data_dir)
777
+ for dir in reg_dirs:
778
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
+ num_reg_images += n_repeats * len(img_paths)
780
 
781
+ for img_path, caption in zip(img_paths, captions):
782
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
+ reg_infos.append(info)
784
 
785
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
 
 
786
 
787
+ print(f"{num_reg_images} reg images.")
788
+ if num_train_images < num_reg_images:
789
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
 
791
+ if num_reg_images == 0:
792
+ print("no regularization images / 正則化画像が見つかりませんでした")
 
 
 
793
  else:
794
+ # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
795
+ n = 0
796
+ first_loop = True
797
+ while n < num_train_images:
798
+ for info in reg_infos:
799
+ if first_loop:
800
+ self.register_image(info)
801
+ n += info.num_repeats
802
+ else:
803
+ info.num_repeats += 1
804
+ n += 1
805
+ if n >= num_train_images:
806
+ break
807
+ first_loop = False
808
 
809
+ self.num_reg_images = num_reg_images
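A worked example of the balancing loop above, with made-up counts and assuming each reg folder uses a repeat count of 1: with 10 training images (after repeats) and 3 regularization images, the first pass registers all three (n = 3), then each further pass bumps one reg image's num_repeats until n reaches 10, leaving repeat counts of 4, 3 and 3.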
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
 
 
 
811
 
812
+ class FineTuningDataset(BaseDataset):
813
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
+
817
+ # メタデータを読み込む
818
+ if os.path.exists(json_file_name):
819
+ print(f"loading existing metadata: {json_file_name}")
820
+ with open(json_file_name, "rt", encoding='utf-8') as f:
821
+ metadata = json.load(f)
822
+ else:
823
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
 
825
+ self.metadata = metadata
826
+ self.train_data_dir = train_data_dir
827
+ self.batch_size = batch_size
828
 
829
+ tags_list = []
830
+ for image_key, img_md in metadata.items():
831
+ # path情報を作る
832
+ if os.path.exists(image_key):
833
+ abs_path = image_key
834
+ else:
835
+ # わりといい加減だがいい方法が思いつかん
836
+ abs_path = glob_images(train_data_dir, image_key)
837
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
+ abs_path = abs_path[0]
839
+
840
+ caption = img_md.get('caption')
841
+ tags = img_md.get('tags')
842
+ if caption is None:
843
+ caption = tags
844
+ elif tags is not None and len(tags) > 0:
845
+ caption = caption + ', ' + tags
846
+ tags_list.append(tags)
847
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
+
849
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
+ image_info.image_size = img_md.get('train_resolution')
851
+
852
+ if not self.color_aug and not self.random_crop:
853
+ # if npz exists, use them
854
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
+
856
+ self.register_image(image_info)
857
+ self.num_train_images = len(metadata) * dataset_repeats
858
+ self.num_reg_images = 0
859
 
860
+ # TODO do not record tag freq when no tag
861
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
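For reference, a minimal metadata entry of the shape read above (the key and values are placeholders):

{
  "img_0001": {
    "caption": "a photo of a cat",
    "tags": "cat, sitting, outdoors",
    "train_resolution": [512, 512]
  }
}

caption and tags are joined with ", " when both are present, and train_resolution, if recorded, is stored as the image size.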
 
863
 
864
  # check existence of all npz files
865
+ use_npz_latents = not (self.color_aug or self.random_crop)
866
  if use_npz_latents:
 
867
  npz_any = False
868
  npz_all = True
 
869
  for image_info in self.image_data.values():
 
 
870
  has_npz = image_info.latents_npz is not None
871
  npz_any = npz_any or has_npz
872
 
873
+ if self.flip_aug:
874
  has_npz = has_npz and image_info.latents_npz_flipped is not None
 
875
  npz_all = npz_all and has_npz
876
 
877
  if npz_any and not npz_all:
 
883
  elif not npz_all:
884
  use_npz_latents = False
885
  print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
886
+ if self.flip_aug:
887
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
  # else:
889
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
 
929
  for image_info in self.image_data.values():
930
  image_info.latents_npz = image_info.latents_npz_flipped = None
931
 
932
+ def image_key_to_npz_file(self, image_key):
933
  base_name = os.path.splitext(image_key)[0]
934
  npz_file_norm = base_name + '.npz'
935
 
 
941
  return npz_file_norm, npz_file_flip
942
 
943
  # image_key is relative path
944
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
 
947
  if not os.path.exists(npz_file_norm):
948
  npz_file_norm = None
 
953
  return npz_file_norm, npz_file_flip
954
 
955
 
 
 
 
 
 
956
  def debug_dataset(train_dataset, show_input_ids=False):
957
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
  print("Escape for exit. / Escキーで中断、終了します")
959
 
960
  train_dataset.set_current_epoch(1)
961
  k = 0
962
+ for i, example in enumerate(train_dataset):
 
 
 
963
  if example['latents'] is not None:
964
  print(f"sample has latents from npz file: {example['latents'].size()}")
965
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
 
1364
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1367
 
1368
 
1369
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
 
1387
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
+ parser.add_argument("--use_8bit_adam", action="store_true",
1391
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1393
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
  parser.add_argument("--mem_eff_attn", action="store_true",
1395
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
  parser.add_argument("--xformers", action="store_true",
 
1398
  parser.add_argument("--vae", type=str, default=None,
1399
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
 
1401
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
  parser.add_argument("--max_train_epochs", type=int, default=None,
1404
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
 
1419
  parser.add_argument("--logging_dir", type=str, default=None,
1420
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
  parser.add_argument("--noise_offset", type=float, default=None,
1427
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
  parser.add_argument("--lowram", action="store_true",
1429
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
 
 
 
 
 
 
 
 
 
 
 
 
 
1431
  if support_dreambooth:
1432
  # DreamBooth training
1433
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
 
1449
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
  parser.add_argument("--caption_extention", type=str, default=None,
1451
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
+ parser.add_argument("--keep_tokens", type=int, default=None,
1453
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
 
1475
  if support_caption_dropout:
1476
  # Textual Inversion はcaptionのdropoutをsupportしない
1477
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1478
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
 
1485
  if support_dreambooth:
 
1504
  # region utils
1505
 
1506
 
 
 
1507
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
  # backward compatibility
1509
  if args.caption_extention is not None:
1510
  args.caption_extension = args.caption_extention
1511
  args.caption_extention = None
1512
 
1513
+ if args.cache_latents:
1514
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
+
1517
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
  if args.resolution is not None:
1519
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
 
1536
 
1537
  def load_tokenizer(args: argparse.Namespace):
1538
  print("prepare tokenizer")
1539
+ if args.v2:
1540
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
+ else:
1542
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
+ if args.max_token_length is not None:
 
 
 
 
 
 
 
 
 
 
 
1544
  print(f"update token length: {args.max_token_length}")
 
 
 
 
 
1545
  return tokenizer
1546
 
1547
 
 
1592
 
1593
 
1594
  def load_target_model(args: argparse.Namespace, weight_dtype):
1595
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
 
 
1596
  if load_stable_diffusion_format:
1597
  print("load StableDiffusion checkpoint")
1598
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
  else:
1600
  print("load Diffusers pretrained models")
1601
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
 
 
 
 
1602
  text_encoder = pipe.text_encoder
1603
  vae = pipe.vae
1604
  unet = pipe.unet
 
1767
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
 
 
 
 
1770
  # endregion
1771
 
1772
  # region 前処理用
networks/lora.py CHANGED
@@ -126,11 +126,6 @@ class LoRANetwork(torch.nn.Module):
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
129
- def set_multiplier(self, multiplier):
130
- self.multiplier = multiplier
131
- for lora in self.text_encoder_loras + self.unet_loras:
132
- lora.multiplier = self.multiplier
133
-
134
  def load_weights(self, file):
135
  if os.path.splitext(file)[1] == '.safetensors':
136
  from safetensors.torch import load_file, safe_open
 
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
 
 
 
 
 
129
  def load_weights(self, file):
130
  if os.path.splitext(file)[1] == '.safetensors':
131
  from safetensors.torch import load_file, safe_open
tools/convert_diffusers20_original_sd.py ADDED
@@ -0,0 +1,89 @@
 
 
1
+ # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ import library.model_util as model_util
9
+
10
+
11
+ def convert(args):
12
+ # 引数を確認する
13
+ load_dtype = torch.float16 if args.fp16 else None
14
+
15
+ save_dtype = None
16
+ if args.fp16:
17
+ save_dtype = torch.float16
18
+ elif args.bf16:
19
+ save_dtype = torch.bfloat16
20
+ elif args.float:
21
+ save_dtype = torch.float
22
+
23
+ is_load_ckpt = os.path.isfile(args.model_to_load)
24
+ is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
25
+
26
+ assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
27
+ assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
28
+
29
+ # モデルを読み込む
30
+ msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
31
+ print(f"loading {msg}: {args.model_to_load}")
32
+
33
+ if is_load_ckpt:
34
+ v2_model = args.v2
35
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
36
+ else:
37
+ pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
38
+ text_encoder = pipe.text_encoder
39
+ vae = pipe.vae
40
+ unet = pipe.unet
41
+
42
+ if args.v1 == args.v2:
43
+ # 自動判定する
44
+ v2_model = unet.config.cross_attention_dim == 1024
45
+ print("checking model version: model is " + ('v2' if v2_model else 'v1'))
46
+ else:
47
+ v2_model = not args.v1
48
+
49
+ # 変換して保存する
50
+ msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
51
+ print(f"converting and saving as {msg}: {args.model_to_save}")
52
+
53
+ if is_save_ckpt:
54
+ original_model = args.model_to_load if is_load_ckpt else None
55
+ key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
56
+ original_model, args.epoch, args.global_step, save_dtype, vae)
57
+ print(f"model saved. total converted state_dict keys: {key_count}")
58
+ else:
59
+ print(f"copy scheduler/tokenizer config from: {args.reference_model}")
60
+ model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
61
+ print(f"model saved.")
62
+
63
+
64
+ if __name__ == '__main__':
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--v1", action='store_true',
67
+ help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
68
+ parser.add_argument("--v2", action='store_true',
69
+ help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
70
+ parser.add_argument("--fp16", action='store_true',
71
+ help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
72
+ parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
73
+ parser.add_argument("--float", action='store_true',
74
+ help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
75
+ parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
76
+ parser.add_argument("--global_step", type=int, default=0,
77
+ help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
78
+ parser.add_argument("--reference_model", type=str, default=None,
79
+ help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
80
+ parser.add_argument("--use_safetensors", action='store_true',
81
+ help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
82
+
83
+ parser.add_argument("model_to_load", type=str, default=None,
84
+ help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
85
+ parser.add_argument("model_to_save", type=str, default=None,
86
+ help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
87
+
88
+ args = parser.parse_args()
89
+ convert(args)
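For reference, two usage sketches for the script above (file names and the reference model are placeholders):

# checkpoint -> Diffusers directory (saving as Diffusers needs --reference_model for scheduler/tokenizer configs)
python tools/convert_diffusers20_original_sd.py --v2 --reference_model stabilityai/stable-diffusion-2 model.ckpt ./converted_diffusers

# Diffusers directory -> single checkpoint, saved as fp16 (an extension on the output selects checkpoint format)
python tools/convert_diffusers20_original_sd.py --fp16 ./some_diffusers_model model_fp16.ckpt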
tools/detect_face_rotate.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
1
+ # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
5
+
6
+ # v2: extract max face if multiple faces are found
7
+ # v3: add crop_ratio option
8
+ # v4: add multiple faces extraction and min/max size
9
+
10
+ import argparse
11
+ import math
12
+ import cv2
13
+ import glob
14
+ import os
15
+ from anime_face_detector import create_detector
16
+ from tqdm import tqdm
17
+ import numpy as np
18
+
19
+ KP_REYE = 11
20
+ KP_LEYE = 19
21
+
22
+ SCORE_THRES = 0.90
23
+
24
+
25
+ def detect_faces(detector, image, min_size):
26
+ preds = detector(image) # bgr
27
+ # print(len(preds))
28
+
29
+ faces = []
30
+ for pred in preds:
31
+ bb = pred['bbox']
32
+ score = bb[-1]
33
+ if score < SCORE_THRES:
34
+ continue
35
+
36
+ left, top, right, bottom = bb[:4]
37
+ cx = int((left + right) / 2)
38
+ cy = int((top + bottom) / 2)
39
+ fw = int(right - left)
40
+ fh = int(bottom - top)
41
+
42
+ lex, ley = pred['keypoints'][KP_LEYE, 0:2]
43
+ rex, rey = pred['keypoints'][KP_REYE, 0:2]
44
+ angle = math.atan2(ley - rey, lex - rex)
45
+ angle = angle / math.pi * 180
46
+
47
+ faces.append((cx, cy, fw, fh, angle))
48
+
49
+ faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順
50
+ return faces
51
+
52
+
53
+ def rotate_image(image, angle, cx, cy):
54
+ h, w = image.shape[0:2]
55
+ rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
56
+
57
+ # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
58
+ # nh = max(h, int(w * math.sin(angle)))
59
+ # nw = max(w, int(h * math.sin(angle)))
60
+ # if nh > h or nw > w:
61
+ # pad_y = nh - h
62
+ # pad_t = pad_y // 2
63
+ # pad_x = nw - w
64
+ # pad_l = pad_x // 2
65
+ # m = np.array([[0, 0, pad_l],
66
+ # [0, 0, pad_t]])
67
+ # rot_mat = rot_mat + m
68
+ # h, w = nh, nw
69
+ # cx += pad_l
70
+ # cy += pad_t
71
+
72
+ result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
73
+ return result, cx, cy
74
+
75
+
76
+ def process(args):
77
+ assert (not args.resize_fit) or args.resize_face_size is None, "resize_fit and resize_face_size can't both be specified / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
78
+ assert args.crop_ratio is None or args.resize_face_size is None, "resize_face_size can't be specified together with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"
79
+
80
+ # アニメ顔検出モデルを読み込む
81
+ print("loading face detector.")
82
+ detector = create_detector('yolov3')
83
+
84
+ # cropの引数を解析する
85
+ if args.crop_size is None:
86
+ crop_width = crop_height = None
87
+ else:
88
+ tokens = args.crop_size.split(',')
89
+ assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
90
+ crop_width, crop_height = [int(t) for t in tokens]
91
+
92
+ if args.crop_ratio is None:
93
+ crop_h_ratio = crop_v_ratio = None
94
+ else:
95
+ tokens = args.crop_ratio.split(',')
96
+ assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
97
+ crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
98
+
99
+ # 画像を処理する
100
+ print("processing.")
101
+ output_extension = ".png"
102
+
103
+ os.makedirs(args.dst_dir, exist_ok=True)
104
+ paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
105
+ glob.glob(os.path.join(args.src_dir, "*.webp"))
106
+ for path in tqdm(paths):
107
+ basename = os.path.splitext(os.path.basename(path))[0]
108
+
109
+ # image = cv2.imread(path) # 日本語ファイル名でエラーになる
110
+ image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
111
+ if len(image.shape) == 2:
112
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
113
+ if image.shape[2] == 4:
114
+ print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
115
+ image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
116
+
117
+ h, w = image.shape[:2]
118
+
119
+ faces = detect_faces(detector, image, args.multiple_faces)
120
+ for i, face in enumerate(faces):
121
+ cx, cy, fw, fh, angle = face
122
+ face_size = max(fw, fh)
123
+ if args.min_size is not None and face_size < args.min_size:
124
+ continue
125
+ if args.max_size is not None and face_size >= args.max_size:
126
+ continue
127
+ face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
128
+
129
+ # オプション指定があれば回転する
130
+ face_img = image
131
+ if args.rotate:
132
+ face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
133
+
134
+ # オプション指定があれば顔を中心に切り出す
135
+ if crop_width is not None or crop_h_ratio is not None:
136
+ cur_crop_width, cur_crop_height = crop_width, crop_height
137
+ if crop_h_ratio is not None:
138
+ cur_crop_width = int(face_size * crop_h_ratio + .5)
139
+ cur_crop_height = int(face_size * crop_v_ratio + .5)
140
+
141
+ # リサイズを必要なら行う
142
+ scale = 1.0
143
+ if args.resize_face_size is not None:
144
+ # 顔サイズを基準にリサイズする
145
+ scale = args.resize_face_size / face_size
146
+ if scale < cur_crop_width / w:
147
+ print(
148
+ f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
149
+ scale = cur_crop_width / w
150
+ if scale < cur_crop_height / h:
151
+ print(
152
+ f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
153
+ scale = cur_crop_height / h
154
+ elif crop_h_ratio is not None:
155
+ # 倍率指定の時にはリサイズしない
156
+ pass
157
+ else:
158
+ # 切り出しサイズ指定あり
159
+ if w < cur_crop_width:
160
+ print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
161
+ scale = cur_crop_width / w
162
+ if h < cur_crop_height:
163
+ print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
164
+ scale = cur_crop_height / h
165
+ if args.resize_fit:
166
+ scale = max(cur_crop_width / w, cur_crop_height / h)
167
+
168
+ if scale != 1.0:
169
+ w = int(w * scale + .5)
170
+ h = int(h * scale + .5)
171
+ face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
172
+ cx = int(cx * scale + .5)
173
+ cy = int(cy * scale + .5)
174
+ fw = int(fw * scale + .5)
175
+ fh = int(fh * scale + .5)
176
+
177
+ cur_crop_width = min(cur_crop_width, face_img.shape[1])
178
+ cur_crop_height = min(cur_crop_height, face_img.shape[0])
179
+
180
+ x = cx - cur_crop_width // 2
181
+ cx = cur_crop_width // 2
182
+ if x < 0:
183
+ cx = cx + x
184
+ x = 0
185
+ elif x + cur_crop_width > w:
186
+ cx = cx + (x + cur_crop_width - w)
187
+ x = w - cur_crop_width
188
+ face_img = face_img[:, x:x+cur_crop_width]
189
+
190
+ y = cy - cur_crop_height // 2
191
+ cy = cur_crop_height // 2
192
+ if y < 0:
193
+ cy = cy + y
194
+ y = 0
195
+ elif y + cur_crop_height > h:
196
+ cy = cy + (y + cur_crop_height - h)
197
+ y = h - cur_crop_height
198
+ face_img = face_img[y:y + cur_crop_height]
199
+
200
+ # # debug
201
+ # print(path, cx, cy, angle)
202
+ # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
203
+ # cv2.imshow("image", crp)
204
+ # if cv2.waitKey() == 27:
205
+ # break
206
+ # cv2.destroyAllWindows()
207
+
208
+ # debug
209
+ if args.debug:
210
+ cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
211
+
212
+ _, buf = cv2.imencode(output_extension, face_img)
213
+ with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
214
+ buf.tofile(f)
215
+
216
+
217
+ if __name__ == '__main__':
218
+ parser = argparse.ArgumentParser()
219
+ parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
220
+ parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
221
+ parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
222
+ parser.add_argument("--resize_fit", action="store_true",
223
+ help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
224
+ parser.add_argument("--resize_face_size", type=int, default=None,
225
+ help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
226
+ parser.add_argument("--crop_size", type=str, default=None,
227
+ help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
228
+ parser.add_argument("--crop_ratio", type=str, default=None,
229
+ help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
230
+ parser.add_argument("--min_size", type=int, default=None,
231
+ help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
232
+ parser.add_argument("--max_size", type=int, default=None,
233
+ help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
234
+ parser.add_argument("--multiple_faces", action="store_true",
235
+ help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
236
+ parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
237
+ args = parser.parse_args()
238
+
239
+ process(args)
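
For readers who want to reuse the detection helpers outside of process(), the following is a minimal sketch of calling detect_faces() and rotate_image() directly. It assumes the repository root is on sys.path and that ./sample.png exists; both assumptions (and the module path tools.detect_face_rotate) are illustrative, not part of this commit.

    import cv2
    import numpy as np
    from anime_face_detector import create_detector
    from tools.detect_face_rotate import detect_faces, rotate_image  # assumes repo root on sys.path

    detector = create_detector('yolov3')
    image = cv2.imdecode(np.fromfile("./sample.png", np.uint8), cv2.IMREAD_UNCHANGED)  # placeholder image
    faces = detect_faces(detector, image, min_size=None)  # sorted largest face first
    if faces:
        cx, cy, fw, fh, angle = faces[0]
        upright, cx, cy = rotate_image(image, angle, cx, cy)  # rotate so the eyes are level
        print(f"largest face at ({cx}, {cy}), {fw}x{fh}px, roll {angle:.1f} deg")
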
tools/resize_images_to_resolution.py ADDED
@@ -0,0 +1,122 @@
1
+ import glob
2
+ import os
3
+ import cv2
4
+ import argparse
5
+ import shutil
6
+ import math
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+
11
+ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
12
+ # Split the max_resolution string by "," and strip any whitespaces
13
+ max_resolutions = [res.strip() for res in max_resolution.split(',')]
14
+
15
+ # # Calculate max_pixels from max_resolution string
16
+ # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
17
+
18
+ # Create destination folder if it does not exist
19
+ if not os.path.exists(dst_img_folder):
20
+ os.makedirs(dst_img_folder)
21
+
22
+ # Select interpolation method
23
+ if interpolation == 'lanczos4':
24
+ cv2_interpolation = cv2.INTER_LANCZOS4
25
+ elif interpolation == 'cubic':
26
+ cv2_interpolation = cv2.INTER_CUBIC
27
+ else:
28
+ cv2_interpolation = cv2.INTER_AREA
29
+
30
+ # Iterate through all files in src_img_folder
31
+ img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
32
+ for filename in os.listdir(src_img_folder):
33
+ # Check if the image is png, jpg or webp etc...
34
+ if not filename.endswith(img_exts):
35
+ # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
36
+ shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
37
+ continue
38
+
39
+ # Load image
40
+ # img = cv2.imread(os.path.join(src_img_folder, filename))
41
+ image = Image.open(os.path.join(src_img_folder, filename))
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+ img = np.array(image, np.uint8)
45
+
46
+ base, _ = os.path.splitext(filename)
47
+ for max_resolution in max_resolutions:
48
+ # Calculate max_pixels from max_resolution string
49
+ max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
50
+
51
+ # Calculate current number of pixels
52
+ current_pixels = img.shape[0] * img.shape[1]
53
+
54
+ # Check if the image needs resizing
55
+ if current_pixels > max_pixels:
56
+ # Calculate scaling factor
57
+ scale_factor = max_pixels / current_pixels
58
+
59
+ # Calculate new dimensions
60
+ new_height = int(img.shape[0] * math.sqrt(scale_factor))
61
+ new_width = int(img.shape[1] * math.sqrt(scale_factor))
62
+
63
+ # Resize image
64
+ img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
65
+ else:
66
+ new_height, new_width = img.shape[0:2]
67
+
68
+ # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
69
+ new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
70
+ new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
71
+
72
+ # Center crop the image to the calculated dimensions
73
+ y = int((img.shape[0] - new_height) / 2)
74
+ x = int((img.shape[1] - new_width) / 2)
75
+ img = img[y:y + new_height, x:x + new_width]
76
+
77
+ # Split filename into base and extension
78
+ new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
79
+
80
+ # Save resized image in dst_img_folder
81
+ # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
82
+ image = Image.fromarray(img)
83
+ image.save(os.path.join(dst_img_folder, new_filename), quality=100)
84
+
85
+ proc = "Resized" if current_pixels > max_pixels else "Saved"
86
+ print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
87
+
88
+ # If other files with same basename, copy them with resolution suffix
89
+ if copy_associated_files:
90
+ asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
91
+ for asoc_file in asoc_files:
92
+ ext = os.path.splitext(asoc_file)[1]
93
+ if ext in img_exts:
94
+ continue
95
+ for max_resolution in max_resolutions:
96
+ new_asoc_file = base + '+' + max_resolution + ext
97
+ print(f"Copy {asoc_file} as {new_asoc_file}")
98
+ shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
99
+
100
+
101
+ def main():
102
+ parser = argparse.ArgumentParser(
103
+ description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
104
+ parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
105
+ parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
106
+ parser.add_argument('--max_resolution', type=str,
107
+ help='Maximum resolution(s), comma-separated, in "widthxheight" format (e.g. "512x512,384x384") / 最大画像サイズをカンマ区切りで指定("512x512,384x384" など)', default="512x512,384x384,256x256,128x128")
108
+ parser.add_argument('--divisible_by', type=int,
109
+ help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
110
+ parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
111
+ default='area', help='Interpolation method for resizing / リサイズ時の補間方法')
112
+ parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
113
+ parser.add_argument('--copy_associated_files', action='store_true',
114
+ help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
115
+
116
+ args = parser.parse_args()
117
+ resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
118
+ args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
119
+
120
+
121
+ if __name__ == '__main__':
122
+ main()
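
Besides the CLI entry point above, resize_images() can be called directly from Python. A minimal sketch, assuming the repository root is importable and that the two folder names exist (all placeholders, not part of this commit):

    from tools.resize_images_to_resolution import resize_images  # assumes repo root on sys.path

    resize_images("./raw_images", "./dataset_images",
                  max_resolution="768x768,512x512",   # one output per resolution, suffixed with "+WxH"
                  divisible_by=64,                    # crop so both sides are multiples of 64
                  interpolation="area",
                  save_as_png=True,
                  copy_associated_files=True)         # also copy matching .txt / .caption files
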
train_db.py CHANGED
@@ -15,11 +15,7 @@ import diffusers
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
- import library.config_util as config_util
19
- from library.config_util import (
20
- ConfigSanitizer,
21
- BlueprintGenerator,
22
- )
23
 
24
 
25
  def collate_fn(examples):
@@ -37,33 +33,24 @@ def train(args):
37
 
38
  tokenizer = train_util.load_tokenizer(args)
39
 
40
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
41
- if args.dataset_config is not None:
42
- print(f"Load dataset config from {args.dataset_config}")
43
- user_config = config_util.load_user_config(args.dataset_config)
44
- ignored = ["train_data_dir", "reg_data_dir"]
45
- if any(getattr(args, attr) is not None for attr in ignored):
46
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
47
- else:
48
- user_config = {
49
- "datasets": [{
50
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
51
- }]
52
- }
53
-
54
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.no_token_padding:
58
- train_dataset_group.disable_token_padding()
 
 
 
 
 
59
 
60
  if args.debug_dataset:
61
- train_util.debug_dataset(train_dataset_group)
62
  return
63
 
64
- if cache_latents:
65
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
-
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
 
@@ -104,7 +91,7 @@ def train(args):
104
  vae.requires_grad_(False)
105
  vae.eval()
106
  with torch.no_grad():
107
- train_dataset_group.cache_latents(vae)
108
  vae.to("cpu")
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
@@ -128,18 +115,38 @@ def train(args):
128
 
129
  # 学習に必要なクラスを準備する
130
  print("prepare optimizer, data loader etc.")
131
  if train_text_encoder:
132
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
133
  else:
134
  trainable_params = unet.parameters()
135
 
136
- _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
137
 
138
  # dataloaderを準備する
139
  # DataLoaderのプロセス数:0はメインプロセスになる
140
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
141
  train_dataloader = torch.utils.data.DataLoader(
142
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
143
 
144
  # 学習ステップ数を計算する
145
  if args.max_train_epochs is not None:
@@ -149,10 +156,9 @@ def train(args):
149
  if args.stop_text_encoder_training is None:
150
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
151
 
152
- # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
153
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
154
- num_training_steps=args.max_train_steps,
155
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
156
 
157
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
158
  if args.full_fp16:
@@ -189,8 +195,8 @@ def train(args):
189
  # 学習する
190
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
191
  print("running training / 学習開始")
192
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
193
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
194
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
195
  print(f" num epochs / epoch数: {num_train_epochs}")
196
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -211,7 +217,7 @@ def train(args):
211
  loss_total = 0.0
212
  for epoch in range(num_train_epochs):
213
  print(f"epoch {epoch+1}/{num_train_epochs}")
214
- train_dataset_group.set_current_epoch(epoch + 1)
215
 
216
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
217
  unet.train()
@@ -275,12 +281,12 @@ def train(args):
275
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
276
 
277
  accelerator.backward(loss)
278
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
279
  if train_text_encoder:
280
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
281
  else:
282
  params_to_clip = unet.parameters()
283
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
284
 
285
  optimizer.step()
286
  lr_scheduler.step()
@@ -291,13 +297,9 @@ def train(args):
291
  progress_bar.update(1)
292
  global_step += 1
293
 
294
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
295
-
296
  current_loss = loss.detach().item()
297
  if args.logging_dir is not None:
298
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
299
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
300
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
301
  accelerator.log(logs, step=global_step)
302
 
303
  if epoch == 0:
@@ -324,8 +326,6 @@ def train(args):
324
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
325
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
326
 
327
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
328
-
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
@@ -352,8 +352,6 @@ if __name__ == '__main__':
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
355
- train_util.add_optimizer_arguments(parser)
356
- config_util.add_config_arguments(parser)
357
 
358
  parser.add_argument("--no_token_padding", action="store_true",
359
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
 
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
+ from library.train_util import DreamBoothDataset
 
 
 
 
19
 
20
 
21
  def collate_fn(examples):
 
33
 
34
  tokenizer = train_util.load_tokenizer(args)
35
 
36
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
37
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
38
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
39
+ args.bucket_reso_steps, args.bucket_no_upscale,
40
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  if args.no_token_padding:
43
+ train_dataset.disable_token_padding()
44
+
45
+ # 学習データのdropout率を設定する
46
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
47
+
48
+ train_dataset.make_buckets()
49
 
50
  if args.debug_dataset:
51
+ train_util.debug_dataset(train_dataset)
52
  return
53
 
 
 
 
54
  # acceleratorを準備する
55
  print("prepare accelerator")
56
 
 
91
  vae.requires_grad_(False)
92
  vae.eval()
93
  with torch.no_grad():
94
+ train_dataset.cache_latents(vae)
95
  vae.to("cpu")
96
  if torch.cuda.is_available():
97
  torch.cuda.empty_cache()
 
115
 
116
  # 学習に必要なクラスを準備する
117
  print("prepare optimizer, data loader etc.")
118
+
119
+ # 8-bit Adamを使う
120
+ if args.use_8bit_adam:
121
+ try:
122
+ import bitsandbytes as bnb
123
+ except ImportError:
124
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
125
+ print("use 8-bit Adam optimizer")
126
+ optimizer_class = bnb.optim.AdamW8bit
127
+ elif args.use_lion_optimizer:
128
+ try:
129
+ import lion_pytorch
130
+ except ImportError:
131
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
132
+ print("use Lion optimizer")
133
+ optimizer_class = lion_pytorch.Lion
134
+ else:
135
+ optimizer_class = torch.optim.AdamW
136
+
137
  if train_text_encoder:
138
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
139
  else:
140
  trainable_params = unet.parameters()
141
 
142
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
143
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
144
 
145
  # dataloaderを準備する
146
  # DataLoaderのプロセス数:0はメインプロセスになる
147
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
148
  train_dataloader = torch.utils.data.DataLoader(
149
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
150
 
151
  # 学習ステップ数を計算する
152
  if args.max_train_epochs is not None:
 
156
  if args.stop_text_encoder_training is None:
157
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
158
 
159
+ # lr schedulerを用意する
160
+ lr_scheduler = diffusers.optimization.get_scheduler(
161
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
162
 
163
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
164
  if args.full_fp16:
 
195
  # 学習する
196
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
197
  print("running training / 学習開始")
198
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
199
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
200
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
201
  print(f" num epochs / epoch数: {num_train_epochs}")
202
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
217
  loss_total = 0.0
218
  for epoch in range(num_train_epochs):
219
  print(f"epoch {epoch+1}/{num_train_epochs}")
220
+ train_dataset.set_current_epoch(epoch + 1)
221
 
222
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
223
  unet.train()
 
281
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
282
 
283
  accelerator.backward(loss)
284
+ if accelerator.sync_gradients:
285
  if train_text_encoder:
286
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
287
  else:
288
  params_to_clip = unet.parameters()
289
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
 
297
  progress_bar.update(1)
298
  global_step += 1
299
 
 
 
300
  current_loss = loss.detach().item()
301
  if args.logging_dir is not None:
302
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
303
  accelerator.log(logs, step=global_step)
304
 
305
  if epoch == 0:
 
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
 
 
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
 
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
 
 
355
 
356
  parser.add_argument("--no_token_padding", action="store_true",
357
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_network.py CHANGED
@@ -1,4 +1,8 @@
 
 
 
1
  from torch.nn.parallel import DistributedDataParallel as DDP
 
2
  import importlib
3
  import argparse
4
  import gc
@@ -11,41 +15,94 @@ import json
11
  from tqdm import tqdm
12
  import torch
13
  from accelerate.utils import set_seed
 
14
  from diffusers import DDPMScheduler
15
 
16
  import library.train_util as train_util
17
- from library.train_util import (
18
- DreamBoothDataset,
19
- )
20
- import library.config_util as config_util
21
- from library.config_util import (
22
- ConfigSanitizer,
23
- BlueprintGenerator,
24
- )
25
 
26
 
27
  def collate_fn(examples):
28
  return examples[0]
29
 
30
 
31
- # TODO 他のスクリプトと共通化する
32
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
33
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
34
 
35
  if args.network_train_unet_only:
36
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
37
  elif args.network_train_text_encoder_only:
38
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
39
  else:
40
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
41
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
42
-
43
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
44
- logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
45
 
46
  return logs
47
 
48
 
49
  def train(args):
50
  session_id = random.randint(0, 2**32)
51
  training_started_at = time.time()
@@ -54,7 +111,6 @@ def train(args):
54
 
55
  cache_latents = args.cache_latents
56
  use_dreambooth_method = args.in_json is None
57
- use_user_config = args.dataset_config is not None
58
 
59
  if args.seed is not None:
60
  set_seed(args.seed)
@@ -62,47 +118,35 @@ def train(args):
62
  tokenizer = train_util.load_tokenizer(args)
63
 
64
  # データセットを準備する
65
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
66
- if use_user_config:
67
- print(f"Load dataset config from {args.dataset_config}")
68
- user_config = config_util.load_user_config(args.dataset_config)
69
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
70
- if any(getattr(args, attr) is not None for attr in ignored):
71
- print(
72
- "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
73
  else:
74
- if use_dreambooth_method:
75
- print("Use DreamBooth method.")
76
- user_config = {
77
- "datasets": [{
78
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
79
- }]
80
- }
81
- else:
82
- print("Train with captions.")
83
- user_config = {
84
- "datasets": [{
85
- "subsets": [{
86
- "image_dir": args.train_data_dir,
87
- "metadata_file": args.in_json,
88
- }]
89
- }]
90
- }
91
-
92
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
93
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
94
 
95
  if args.debug_dataset:
96
- train_util.debug_dataset(train_dataset_group)
97
  return
98
- if len(train_dataset_group) == 0:
99
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
100
  return
101
 
102
- if cache_latents:
103
- assert train_dataset_group.is_latent_cacheable(
104
- ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
105
-
106
  # acceleratorを準備する
107
  print("prepare accelerator")
108
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -117,7 +161,7 @@ def train(args):
117
  if args.lowram:
118
  text_encoder.to("cuda")
119
  unet.to("cuda")
120
-
121
  # モデルに xformers とか memory efficient attention を組み込む
122
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
123
 
@@ -127,7 +171,7 @@ def train(args):
127
  vae.requires_grad_(False)
128
  vae.eval()
129
  with torch.no_grad():
130
- train_dataset_group.cache_latents(vae)
131
  vae.to("cpu")
132
  if torch.cuda.is_available():
133
  torch.cuda.empty_cache()
@@ -164,14 +208,36 @@ def train(args):
164
  # 学習に必要なクラスを準備する
165
  print("prepare optimizer, data loader etc.")
166
 
167
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
168
- optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
169
 
170
  # dataloaderを準備する
171
  # DataLoaderのプロセス数:0はメインプロセスになる
172
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
173
  train_dataloader = torch.utils.data.DataLoader(
174
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
175
 
176
  # 学習ステップ数を計算する
177
  if args.max_train_epochs is not None:
@@ -179,9 +245,11 @@ def train(args):
179
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
180
 
181
  # lr schedulerを用意する
182
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
183
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
184
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
 
185
 
186
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
187
  if args.full_fp16:
@@ -249,19 +317,17 @@ def train(args):
249
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
250
 
251
  # 学習する
252
- # TODO: find a way to handle total batch size when there are multiple datasets
253
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
254
  print("running training / 学習開始")
255
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
256
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
257
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
258
  print(f" num epochs / epoch数: {num_train_epochs}")
259
- print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
260
- # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
261
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
262
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
263
 
264
- # TODO refactor metadata creation and move to util
265
  metadata = {
266
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
267
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -269,10 +335,12 @@ def train(args):
269
  "ss_learning_rate": args.learning_rate,
270
  "ss_text_encoder_lr": args.text_encoder_lr,
271
  "ss_unet_lr": args.unet_lr,
272
- "ss_num_train_images": train_dataset_group.num_train_images,
273
- "ss_num_reg_images": train_dataset_group.num_reg_images,
274
  "ss_num_batches_per_epoch": len(train_dataloader),
275
  "ss_num_epochs": num_train_epochs,
 
 
276
  "ss_gradient_checkpointing": args.gradient_checkpointing,
277
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
278
  "ss_max_train_steps": args.max_train_steps,
@@ -284,149 +352,29 @@ def train(args):
284
  "ss_mixed_precision": args.mixed_precision,
285
  "ss_full_fp16": bool(args.full_fp16),
286
  "ss_v2": bool(args.v2),
 
287
  "ss_clip_skip": args.clip_skip,
288
  "ss_max_token_length": args.max_token_length,
 
 
 
 
289
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
290
  "ss_seed": args.seed,
291
- "ss_lowram": args.lowram,
292
  "ss_noise_offset": args.noise_offset,
 
 
 
 
293
  "ss_training_comment": args.training_comment, # will not be updated after training
294
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
295
- "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
296
- "ss_max_grad_norm": args.max_grad_norm,
297
- "ss_caption_dropout_rate": args.caption_dropout_rate,
298
- "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
299
- "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
300
- "ss_face_crop_aug_range": args.face_crop_aug_range,
301
- "ss_prior_loss_weight": args.prior_loss_weight,
302
  }
303
 
304
- if use_user_config:
305
- # save metadata of multiple datasets
306
- # NOTE: pack "ss_datasets" value as json one time
307
- # or should also pack nested collections as json?
308
- datasets_metadata = []
309
- tag_frequency = {} # merge tag frequency for metadata editor
310
- dataset_dirs_info = {} # merge subset dirs for metadata editor
311
-
312
- for dataset in train_dataset_group.datasets:
313
- is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
314
- dataset_metadata = {
315
- "is_dreambooth": is_dreambooth_dataset,
316
- "batch_size_per_device": dataset.batch_size,
317
- "num_train_images": dataset.num_train_images, # includes repeating
318
- "num_reg_images": dataset.num_reg_images,
319
- "resolution": (dataset.width, dataset.height),
320
- "enable_bucket": bool(dataset.enable_bucket),
321
- "min_bucket_reso": dataset.min_bucket_reso,
322
- "max_bucket_reso": dataset.max_bucket_reso,
323
- "tag_frequency": dataset.tag_frequency,
324
- "bucket_info": dataset.bucket_info,
325
- }
326
-
327
- subsets_metadata = []
328
- for subset in dataset.subsets:
329
- subset_metadata = {
330
- "img_count": subset.img_count,
331
- "num_repeats": subset.num_repeats,
332
- "color_aug": bool(subset.color_aug),
333
- "flip_aug": bool(subset.flip_aug),
334
- "random_crop": bool(subset.random_crop),
335
- "shuffle_caption": bool(subset.shuffle_caption),
336
- "keep_tokens": subset.keep_tokens,
337
- }
338
-
339
- image_dir_or_metadata_file = None
340
- if subset.image_dir:
341
- image_dir = os.path.basename(subset.image_dir)
342
- subset_metadata["image_dir"] = image_dir
343
- image_dir_or_metadata_file = image_dir
344
-
345
- if is_dreambooth_dataset:
346
- subset_metadata["class_tokens"] = subset.class_tokens
347
- subset_metadata["is_reg"] = subset.is_reg
348
- if subset.is_reg:
349
- image_dir_or_metadata_file = None # not merging reg dataset
350
- else:
351
- metadata_file = os.path.basename(subset.metadata_file)
352
- subset_metadata["metadata_file"] = metadata_file
353
- image_dir_or_metadata_file = metadata_file # may overwrite
354
-
355
- subsets_metadata.append(subset_metadata)
356
-
357
- # merge dataset dir: not reg subset only
358
- # TODO update additional-network extension to show detailed dataset config from metadata
359
- if image_dir_or_metadata_file is not None:
360
- # datasets may have a certain dir multiple times
361
- v = image_dir_or_metadata_file
362
- i = 2
363
- while v in dataset_dirs_info:
364
- v = image_dir_or_metadata_file + f" ({i})"
365
- i += 1
366
- image_dir_or_metadata_file = v
367
-
368
- dataset_dirs_info[image_dir_or_metadata_file] = {
369
- "n_repeats": subset.num_repeats,
370
- "img_count": subset.img_count
371
- }
372
-
373
- dataset_metadata["subsets"] = subsets_metadata
374
- datasets_metadata.append(dataset_metadata)
375
-
376
- # merge tag frequency:
377
- for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
378
- # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
379
- # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
380
- # なので、ここで複数datasetの回数を合算してもあまり意味はない
381
- if ds_dir_name in tag_frequency:
382
- continue
383
- tag_frequency[ds_dir_name] = ds_freq_for_dir
384
-
385
- metadata["ss_datasets"] = json.dumps(datasets_metadata)
386
- metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
387
- metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
388
- else:
389
- # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
390
- assert len(
391
- train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
392
-
393
- dataset = train_dataset_group.datasets[0]
394
-
395
- dataset_dirs_info = {}
396
- reg_dataset_dirs_info = {}
397
- if use_dreambooth_method:
398
- for subset in dataset.subsets:
399
- info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
400
- info[os.path.basename(subset.image_dir)] = {
401
- "n_repeats": subset.num_repeats,
402
- "img_count": subset.img_count
403
- }
404
- else:
405
- for subset in dataset.subsets:
406
- dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
407
- "n_repeats": subset.num_repeats,
408
- "img_count": subset.img_count
409
- }
410
-
411
- metadata.update({
412
- "ss_batch_size_per_device": args.train_batch_size,
413
- "ss_total_batch_size": total_batch_size,
414
- "ss_resolution": args.resolution,
415
- "ss_color_aug": bool(args.color_aug),
416
- "ss_flip_aug": bool(args.flip_aug),
417
- "ss_random_crop": bool(args.random_crop),
418
- "ss_shuffle_caption": bool(args.shuffle_caption),
419
- "ss_enable_bucket": bool(dataset.enable_bucket),
420
- "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
421
- "ss_min_bucket_reso": dataset.min_bucket_reso,
422
- "ss_max_bucket_reso": dataset.max_bucket_reso,
423
- "ss_keep_tokens": args.keep_tokens,
424
- "ss_dataset_dirs": json.dumps(dataset_dirs_info),
425
- "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
426
- "ss_tag_frequency": json.dumps(dataset.tag_frequency),
427
- "ss_bucket_info": json.dumps(dataset.bucket_info),
428
- })
429
-
430
  # uncomment if another network is added
431
  # for key, value in net_kwargs.items():
432
  # metadata["ss_arg_" + key] = value
@@ -462,7 +410,7 @@ def train(args):
462
  loss_total = 0.0
463
  for epoch in range(num_train_epochs):
464
  print(f"epoch {epoch+1}/{num_train_epochs}")
465
- train_dataset_group.set_current_epoch(epoch + 1)
466
 
467
  metadata["ss_epoch"] = str(epoch+1)
468
 
@@ -499,7 +447,7 @@ def train(args):
499
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
500
 
501
  # Predict the noise residual
502
- with accelerator.autocast():
503
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
504
 
505
  if args.v_parameterization:
@@ -517,9 +465,9 @@ def train(args):
517
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
518
 
519
  accelerator.backward(loss)
520
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
521
  params_to_clip = network.get_trainable_params()
522
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
523
 
524
  optimizer.step()
525
  lr_scheduler.step()
@@ -530,8 +478,6 @@ def train(args):
530
  progress_bar.update(1)
531
  global_step += 1
532
 
533
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
534
-
535
  current_loss = loss.detach().item()
536
  if epoch == 0:
537
  loss_list.append(current_loss)
@@ -562,7 +508,6 @@ def train(args):
562
  def save_func():
563
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
564
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
565
- metadata["ss_training_finished_at"] = str(time.time())
566
  print(f"saving checkpoint: {ckpt_file}")
567
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
568
 
@@ -577,12 +522,9 @@ def train(args):
577
  if saving and args.save_state:
578
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
579
 
580
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
581
-
582
  # end of epoch
583
 
584
  metadata["ss_epoch"] = str(num_train_epochs)
585
- metadata["ss_training_finished_at"] = str(time.time())
586
 
587
  is_main_process = accelerator.is_main_process
588
  if is_main_process:
@@ -613,8 +555,6 @@ if __name__ == '__main__':
613
  train_util.add_sd_models_arguments(parser)
614
  train_util.add_dataset_arguments(parser, True, True, True)
615
  train_util.add_training_arguments(parser, True)
616
- train_util.add_optimizer_arguments(parser)
617
- config_util.add_config_arguments(parser)
618
 
619
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
620
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -622,6 +562,10 @@ if __name__ == '__main__':
622
 
623
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
624
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
625
 
626
  parser.add_argument("--network_weights", type=str, default=None,
627
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
1
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
+ from torch.optim import Optimizer
3
+ from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
 
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
+ import diffusers
19
  from diffusers import DDPMScheduler
20
 
21
  import library.train_util as train_util
22
+ from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
23
 
24
 
25
  def collate_fn(examples):
26
  return examples[0]
27
 
28
 
 
29
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
30
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
31
 
32
  if args.network_train_unet_only:
33
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
34
  elif args.network_train_text_encoder_only:
35
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
36
  else:
37
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
38
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
 
 
 
39
 
40
  return logs
41
 
42
 
43
+ # Monkeypatch a newer get_scheduler() function, overriding the current version of diffusers.optimization.get_scheduler.
44
+ # The code is taken from https://github.com/huggingface/diffusers (diffusers.optimization), commit d87cc15977b87160c30abaace3894e802ad9e1e6,
45
+ # which is a newer release of diffusers than the one currently packaged with sd-scripts.
46
+ # This code can be removed once a newer diffusers version (v0.12.1 or greater) is tested and integrated into sd-scripts.
47
+
48
+
49
+ def get_scheduler_fix(
50
+ name: Union[str, SchedulerType],
51
+ optimizer: Optimizer,
52
+ num_warmup_steps: Optional[int] = None,
53
+ num_training_steps: Optional[int] = None,
54
+ num_cycles: int = 1,
55
+ power: float = 1.0,
56
+ ):
57
+ """
58
+ Unified API to get any scheduler from its name.
59
+ Args:
60
+ name (`str` or `SchedulerType`):
61
+ The name of the scheduler to use.
62
+ optimizer (`torch.optim.Optimizer`):
63
+ The optimizer that will be used during training.
64
+ num_warmup_steps (`int`, *optional*):
65
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
66
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
67
+ num_training_steps (`int`, *optional*):
68
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
69
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
70
+ num_cycles (`int`, *optional*):
71
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
72
+ power (`float`, *optional*, defaults to 1.0):
73
+ Power factor. See `POLYNOMIAL` scheduler
74
+ last_epoch (`int`, *optional*, defaults to -1):
75
+ The index of the last epoch when resuming training.
76
+ """
77
+ name = SchedulerType(name)
78
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
79
+ if name == SchedulerType.CONSTANT:
80
+ return schedule_func(optimizer)
81
+
82
+ # All other schedulers require `num_warmup_steps`
83
+ if num_warmup_steps is None:
84
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
85
+
86
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
87
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
88
+
89
+ # All other schedulers require `num_training_steps`
90
+ if num_training_steps is None:
91
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
92
+
93
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
94
+ return schedule_func(
95
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
96
+ )
97
+
98
+ if name == SchedulerType.POLYNOMIAL:
99
+ return schedule_func(
100
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
101
+ )
102
+
103
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
104
+
105
+
106
  def train(args):
107
  session_id = random.randint(0, 2**32)
108
  training_started_at = time.time()
 
111
 
112
  cache_latents = args.cache_latents
113
  use_dreambooth_method = args.in_json is None
 
114
 
115
  if args.seed is not None:
116
  set_seed(args.seed)
 
118
  tokenizer = train_util.load_tokenizer(args)
119
 
120
  # データセットを準備する
121
+ if use_dreambooth_method:
122
+ print("Use DreamBooth method.")
123
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
124
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
125
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
126
+ args.bucket_reso_steps, args.bucket_no_upscale,
127
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
128
+ args.random_crop, args.debug_dataset)
129
  else:
130
+ print("Train with captions.")
131
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
132
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
133
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
134
+ args.bucket_reso_steps, args.bucket_no_upscale,
135
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
136
+ args.dataset_repeats, args.debug_dataset)
137
+
138
+ # 学習データのdropout率を設定する
139
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
140
+
141
+ train_dataset.make_buckets()
 
 
 
 
 
 
 
 
142
 
143
  if args.debug_dataset:
144
+ train_util.debug_dataset(train_dataset)
145
  return
146
+ if len(train_dataset) == 0:
147
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
148
  return
149
 
 
 
 
 
150
  # acceleratorを準備する
151
  print("prepare accelerator")
152
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
161
  if args.lowram:
162
  text_encoder.to("cuda")
163
  unet.to("cuda")
164
+
165
  # モデルに xformers とか memory efficient attention を組み込む
166
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
167
 
 
171
  vae.requires_grad_(False)
172
  vae.eval()
173
  with torch.no_grad():
174
+ train_dataset.cache_latents(vae)
175
  vae.to("cpu")
176
  if torch.cuda.is_available():
177
  torch.cuda.empty_cache()
 
208
  # 学習に必要なクラスを準備する
209
  print("prepare optimizer, data loader etc.")
210
 
211
+ # 8-bit Adamを使う
212
+ if args.use_8bit_adam:
213
+ try:
214
+ import bitsandbytes as bnb
215
+ except ImportError:
216
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
217
+ print("use 8-bit Adam optimizer")
218
+ optimizer_class = bnb.optim.AdamW8bit
219
+ elif args.use_lion_optimizer:
220
+ try:
221
+ import lion_pytorch
222
+ except ImportError:
223
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
224
+ print("use Lion optimizer")
225
+ optimizer_class = lion_pytorch.Lion
226
+ else:
227
+ optimizer_class = torch.optim.AdamW
228
+
229
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
230
+
231
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
232
+
233
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
234
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235
 
236
  # dataloaderを準備する
237
  # DataLoaderのプロセス数:0はメインプロセスになる
238
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
239
  train_dataloader = torch.utils.data.DataLoader(
240
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
241
 
242
  # 学習ステップ数を計算する
243
  if args.max_train_epochs is not None:
 
245
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
246
 
247
  # lr schedulerを用意する
248
+ # lr_scheduler = diffusers.optimization.get_scheduler(
249
+ lr_scheduler = get_scheduler_fix(
250
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
251
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
252
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
253
 
254
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
255
  if args.full_fp16:
 
317
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
318
 
319
  # 学習する
 
320
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
321
  print("running training / 学習開始")
322
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
323
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
324
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
325
  print(f" num epochs / epoch数: {num_train_epochs}")
326
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
327
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
328
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
329
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
330
 
 
331
  metadata = {
332
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
333
  "ss_training_started_at": training_started_at, # unix timestamp
 
335
  "ss_learning_rate": args.learning_rate,
336
  "ss_text_encoder_lr": args.text_encoder_lr,
337
  "ss_unet_lr": args.unet_lr,
338
+ "ss_num_train_images": train_dataset.num_train_images, # includes repeating
339
+ "ss_num_reg_images": train_dataset.num_reg_images,
340
  "ss_num_batches_per_epoch": len(train_dataloader),
341
  "ss_num_epochs": num_train_epochs,
342
+ "ss_batch_size_per_device": args.train_batch_size,
343
+ "ss_total_batch_size": total_batch_size,
344
  "ss_gradient_checkpointing": args.gradient_checkpointing,
345
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
346
  "ss_max_train_steps": args.max_train_steps,
 
352
  "ss_mixed_precision": args.mixed_precision,
353
  "ss_full_fp16": bool(args.full_fp16),
354
  "ss_v2": bool(args.v2),
355
+ "ss_resolution": args.resolution,
356
  "ss_clip_skip": args.clip_skip,
357
  "ss_max_token_length": args.max_token_length,
358
+ "ss_color_aug": bool(args.color_aug),
359
+ "ss_flip_aug": bool(args.flip_aug),
360
+ "ss_random_crop": bool(args.random_crop),
361
+ "ss_shuffle_caption": bool(args.shuffle_caption),
362
  "ss_cache_latents": bool(args.cache_latents),
363
+ "ss_enable_bucket": bool(train_dataset.enable_bucket),
364
+ "ss_min_bucket_reso": train_dataset.min_bucket_reso,
365
+ "ss_max_bucket_reso": train_dataset.max_bucket_reso,
366
  "ss_seed": args.seed,
367
+ "ss_keep_tokens": args.keep_tokens,
368
  "ss_noise_offset": args.noise_offset,
369
+ "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
370
+ "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
371
+ "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
372
+ "ss_bucket_info": json.dumps(train_dataset.bucket_info),
373
  "ss_training_comment": args.training_comment, # will not be updated after training
374
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
375
+ "ss_optimizer": optimizer_name
376
  }
377
 
378
  # uncomment if another network is added
379
  # for key, value in net_kwargs.items():
380
  # metadata["ss_arg_" + key] = value
 
410
  loss_total = 0.0
411
  for epoch in range(num_train_epochs):
412
  print(f"epoch {epoch+1}/{num_train_epochs}")
413
+ train_dataset.set_current_epoch(epoch + 1)
414
 
415
  metadata["ss_epoch"] = str(epoch+1)
416
 
 
447
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
448
 
449
  # Predict the noise residual
450
+ with autocast():
451
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
452
 
453
  if args.v_parameterization:
 
465
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
466
 
467
  accelerator.backward(loss)
468
+ if accelerator.sync_gradients:
469
  params_to_clip = network.get_trainable_params()
470
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
471
 
472
  optimizer.step()
473
  lr_scheduler.step()
 
478
  progress_bar.update(1)
479
  global_step += 1
480
 
 
 
481
  current_loss = loss.detach().item()
482
  if epoch == 0:
483
  loss_list.append(current_loss)
 
508
  def save_func():
509
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
510
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
511
  print(f"saving checkpoint: {ckpt_file}")
512
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
513
 
 
522
  if saving and args.save_state:
523
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
524
 
 
 
525
  # end of epoch
526
 
527
  metadata["ss_epoch"] = str(num_train_epochs)
 
528
 
529
  is_main_process = accelerator.is_main_process
530
  if is_main_process:
 
555
  train_util.add_sd_models_arguments(parser)
556
  train_util.add_dataset_arguments(parser, True, True, True)
557
  train_util.add_training_arguments(parser, True)
 
 
558
 
559
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
560
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
562
 
563
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
564
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
565
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
566
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
567
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
568
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
569
 
570
  parser.add_argument("--network_weights", type=str, default=None,
571
  help="pretrained weights for network / 学習するネットワークの初期重み")
train_network_opt.py CHANGED
@@ -1,5 +1,8 @@
 
 
1
  from torch.cuda.amp import autocast
2
  from torch.nn.parallel import DistributedDataParallel as DDP
 
3
  import importlib
4
  import argparse
5
  import gc
@@ -12,49 +15,138 @@ import json
12
  from tqdm import tqdm
13
  import torch
14
  from accelerate.utils import set_seed
15
- #import diffusers
16
  from diffusers import DDPMScheduler
17
-
 
 
 
18
  ##### バケット拡張のためのモジュール
19
  import append_module
20
  ######
21
  import library.train_util as train_util
22
- from library.train_util import (
23
- DreamBoothDataset,
24
- )
25
- import library.config_util as config_util
26
- from library.config_util import (
27
- ConfigSanitizer,
28
- BlueprintGenerator,
29
- )
30
 
31
 
32
  def collate_fn(examples):
33
  return examples[0]
34
 
35
 
36
- # TODO 他のスクリプトと共通化する
37
- def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
38
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
39
- if not args.split_lora_networks:
40
- if args.network_train_unet_only:
41
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
42
- elif args.network_train_text_encoder_only:
43
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
44
- else:
45
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
46
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
47
  else:
48
  last_lrs = lr_scheduler.get_last_lr()
49
- for last_lr, t_name in zip(last_lrs, split_names):
50
- logs[f"lr/{t_name}"] = float(last_lr)
51
- #D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
52
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
53
- logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  return logs
56
 
57
 
 
 
 
58
  def train(args):
59
  session_id = random.randint(0, 2**32)
60
  training_started_at = time.time()
@@ -63,7 +155,6 @@ def train(args):
63
 
64
  cache_latents = args.cache_latents
65
  use_dreambooth_method = args.in_json is None
66
- use_user_config = args.dataset_config is not None
67
 
68
  if args.seed is not None:
69
  set_seed(args.seed)
@@ -71,72 +162,52 @@ def train(args):
71
  tokenizer = train_util.load_tokenizer(args)
72
 
73
  # データセットを準備する
74
- if args.min_resolution:
75
- args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
76
- if len(args.min_resolution) == 1:
77
- args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
78
- blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
79
- else:
80
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
81
- if use_user_config:
82
- print(f"Load dataset config from {args.dataset_config}")
83
- user_config = config_util.load_user_config(args.dataset_config)
84
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
85
- if any(getattr(args, attr) is not None for attr in ignored):
86
- print(
87
- "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
88
- else:
89
- if use_dreambooth_method:
90
- print("Use DreamBooth method.")
91
- user_config = {
92
- "datasets": [{
93
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
94
- }]
95
- }
96
- else:
97
- print("Train with captions.")
98
- user_config = {
99
- "datasets": [{
100
- "subsets": [{
101
- "image_dir": args.train_data_dir,
102
- "metadata_file": args.in_json,
103
- }]
104
- }]
105
- }
106
-
107
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
108
- if args.min_resolution:
109
- train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
110
  else:
111
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
 
 
112
 
113
  if args.debug_dataset:
114
- train_util.debug_dataset(train_dataset_group)
115
  return
116
- if len(train_dataset_group) == 0:
117
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
118
  return
119
 
120
- if cache_latents:
121
- assert train_dataset_group.is_latent_cacheable(
122
- ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
123
-
124
  # acceleratorを準備する
125
  print("prepare accelerator")
126
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
127
- is_main_process = accelerator.is_main_process
128
 
129
  # mixed precisionに対応した型を用意しておき適宜castする
130
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
131
 
132
  # モデルを読み込む
133
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
134
-
135
- # work on low-ram device
136
- if args.lowram:
137
- text_encoder.to("cuda")
138
- unet.to("cuda")
139
-
140
  # モデルに xformers とか memory efficient attention を組み込む
141
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
142
 
@@ -146,15 +217,13 @@ def train(args):
146
  vae.requires_grad_(False)
147
  vae.eval()
148
  with torch.no_grad():
149
- train_dataset_group.cache_latents(vae)
150
  vae.to("cpu")
151
  if torch.cuda.is_available():
152
  torch.cuda.empty_cache()
153
  gc.collect()
154
 
155
  # prepare network
156
- import sys
157
- sys.path.append(os.path.dirname(__file__))
158
  print("import network module:", args.network_module)
159
  network_module = importlib.import_module(args.network_module)
160
 
@@ -184,65 +253,188 @@ def train(args):
184
 
185
  # 学習に必要なクラスを準備する
186
  print("prepare optimizer, data loader etc.")
187
- split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
188
-
189
- used_names = None
 
 
 
 
190
  if args.split_lora_networks:
191
- lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
192
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
193
- append_module.replace_prepare_optimizer_params(network, network_module)
194
- trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
195
  else:
196
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
197
- if split_flag:
198
- _t_lr = 0.
199
- _u_lr = 0.
200
- if args.text_encoder_lr:
201
- _t_lr = args.text_encoder_lr
202
- if args.unet_lr:
203
- _u_lr = args.unet_lr
204
- adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
205
-
206
- optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
207
- if args.use_lookahead:
208
- try:
209
- import torch_optimizer
210
- lookahed_arg = {"k": 5, "alpha": 0.5}
211
- if args.lookahead_arg is not None:
212
- for _arg in args.lookahead_arg:
213
- k, v = _arg.split("=")
214
- if k == "k":
215
- lookahed_arg[k] = int(v)
216
- else:
217
- lookahed_arg[k] = float(v)
218
- optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
219
- except:
220
- print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
221
  # dataloaderを準備する
222
  # DataLoaderのプロセス数:0はメインプロセスになる
223
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
224
  train_dataloader = torch.utils.data.DataLoader(
225
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
226
 
227
  # 学習ステップ数を計算する
228
  if args.max_train_epochs is not None:
229
- args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
230
- if is_main_process:
231
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
232
 
233
  # lr schedulerを用意する
234
- if args.lr_scheduler.startswith("adafactor") and split_flag:
235
- lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
 
 
236
  else:
237
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
238
- num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
239
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
240
-
 
241
  #追加機能の設定をコメントに追記して残す
242
- if args.use_lookahead:
243
- args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
244
- if args.split_lora_networks:
245
- args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
246
  if args.min_resolution:
247
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
248
 
@@ -312,21 +504,17 @@ def train(args):
312
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
313
 
314
  # 学習する
315
- # TODO: find a way to handle total batch size when there are multiple datasets
316
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
317
-
318
- if is_main_process:
319
- print("running training / 学習開始")
320
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
321
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
322
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
323
- print(f" num epochs / epoch数: {num_train_epochs}")
324
- print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
325
- # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
326
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
327
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
328
-
329
- # TODO refactor metadata creation and move to util
330
  metadata = {
331
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
332
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -334,10 +522,12 @@ def train(args):
334
  "ss_learning_rate": args.learning_rate,
335
  "ss_text_encoder_lr": args.text_encoder_lr,
336
  "ss_unet_lr": args.unet_lr,
337
- "ss_num_train_images": train_dataset_group.num_train_images,
338
- "ss_num_reg_images": train_dataset_group.num_reg_images,
339
  "ss_num_batches_per_epoch": len(train_dataloader),
340
  "ss_num_epochs": num_train_epochs,
 
 
341
  "ss_gradient_checkpointing": args.gradient_checkpointing,
342
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
343
  "ss_max_train_steps": args.max_train_steps,
@@ -349,156 +539,32 @@ def train(args):
349
  "ss_mixed_precision": args.mixed_precision,
350
  "ss_full_fp16": bool(args.full_fp16),
351
  "ss_v2": bool(args.v2),
 
352
  "ss_clip_skip": args.clip_skip,
353
  "ss_max_token_length": args.max_token_length,
 
 
 
 
354
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
355
  "ss_seed": args.seed,
356
- "ss_lowram": args.lowram,
357
  "ss_noise_offset": args.noise_offset,
 
 
 
 
358
  "ss_training_comment": args.training_comment, # will not be updated after training
359
- "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
360
- "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
361
- "ss_max_grad_norm": args.max_grad_norm,
362
- "ss_caption_dropout_rate": args.caption_dropout_rate,
363
- "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
364
- "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
365
- "ss_face_crop_aug_range": args.face_crop_aug_range,
366
- "ss_prior_loss_weight": args.prior_loss_weight,
367
  }
368
 
369
- if use_user_config:
370
- # save metadata of multiple datasets
371
- # NOTE: pack "ss_datasets" value as json one time
372
- # or should also pack nested collections as json?
373
- datasets_metadata = []
374
- tag_frequency = {} # merge tag frequency for metadata editor
375
- dataset_dirs_info = {} # merge subset dirs for metadata editor
376
-
377
- for dataset in train_dataset_group.datasets:
378
- is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
379
- dataset_metadata = {
380
- "is_dreambooth": is_dreambooth_dataset,
381
- "batch_size_per_device": dataset.batch_size,
382
- "num_train_images": dataset.num_train_images, # includes repeating
383
- "num_reg_images": dataset.num_reg_images,
384
- "resolution": (dataset.width, dataset.height),
385
- "enable_bucket": bool(dataset.enable_bucket),
386
- "min_bucket_reso": dataset.min_bucket_reso,
387
- "max_bucket_reso": dataset.max_bucket_reso,
388
- "tag_frequency": dataset.tag_frequency,
389
- "bucket_info": dataset.bucket_info,
390
- }
391
-
392
- subsets_metadata = []
393
- for subset in dataset.subsets:
394
- subset_metadata = {
395
- "img_count": subset.img_count,
396
- "num_repeats": subset.num_repeats,
397
- "color_aug": bool(subset.color_aug),
398
- "flip_aug": bool(subset.flip_aug),
399
- "random_crop": bool(subset.random_crop),
400
- "shuffle_caption": bool(subset.shuffle_caption),
401
- "keep_tokens": subset.keep_tokens,
402
- }
403
-
404
- image_dir_or_metadata_file = None
405
- if subset.image_dir:
406
- image_dir = os.path.basename(subset.image_dir)
407
- subset_metadata["image_dir"] = image_dir
408
- image_dir_or_metadata_file = image_dir
409
-
410
- if is_dreambooth_dataset:
411
- subset_metadata["class_tokens"] = subset.class_tokens
412
- subset_metadata["is_reg"] = subset.is_reg
413
- if subset.is_reg:
414
- image_dir_or_metadata_file = None # not merging reg dataset
415
- else:
416
- metadata_file = os.path.basename(subset.metadata_file)
417
- subset_metadata["metadata_file"] = metadata_file
418
- image_dir_or_metadata_file = metadata_file # may overwrite
419
-
420
- subsets_metadata.append(subset_metadata)
421
-
422
- # merge dataset dir: not reg subset only
423
- # TODO update additional-network extension to show detailed dataset config from metadata
424
- if image_dir_or_metadata_file is not None:
425
- # datasets may have a certain dir multiple times
426
- v = image_dir_or_metadata_file
427
- i = 2
428
- while v in dataset_dirs_info:
429
- v = image_dir_or_metadata_file + f" ({i})"
430
- i += 1
431
- image_dir_or_metadata_file = v
432
-
433
- dataset_dirs_info[image_dir_or_metadata_file] = {
434
- "n_repeats": subset.num_repeats,
435
- "img_count": subset.img_count
436
- }
437
-
438
- dataset_metadata["subsets"] = subsets_metadata
439
- datasets_metadata.append(dataset_metadata)
440
-
441
- # merge tag frequency:
442
- for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
443
- # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
444
- # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
445
- # なので、ここで複数datasetの回数を合算してもあまり意味はない
446
- if ds_dir_name in tag_frequency:
447
- continue
448
- tag_frequency[ds_dir_name] = ds_freq_for_dir
449
-
450
- metadata["ss_datasets"] = json.dumps(datasets_metadata)
451
- metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
452
- metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
453
- else:
454
- # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
455
- assert len(
456
- train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
457
-
458
- dataset = train_dataset_group.datasets[0]
459
-
460
- dataset_dirs_info = {}
461
- reg_dataset_dirs_info = {}
462
- if use_dreambooth_method:
463
- for subset in dataset.subsets:
464
- info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
465
- info[os.path.basename(subset.image_dir)] = {
466
- "n_repeats": subset.num_repeats,
467
- "img_count": subset.img_count
468
- }
469
- else:
470
- for subset in dataset.subsets:
471
- dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
472
- "n_repeats": subset.num_repeats,
473
- "img_count": subset.img_count
474
- }
475
-
476
- metadata.update({
477
- "ss_batch_size_per_device": args.train_batch_size,
478
- "ss_total_batch_size": total_batch_size,
479
- "ss_resolution": args.resolution,
480
- "ss_color_aug": bool(args.color_aug),
481
- "ss_flip_aug": bool(args.flip_aug),
482
- "ss_random_crop": bool(args.random_crop),
483
- "ss_shuffle_caption": bool(args.shuffle_caption),
484
- "ss_enable_bucket": bool(dataset.enable_bucket),
485
- "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
486
- "ss_min_bucket_reso": dataset.min_bucket_reso,
487
- "ss_max_bucket_reso": dataset.max_bucket_reso,
488
- "ss_keep_tokens": args.keep_tokens,
489
- "ss_dataset_dirs": json.dumps(dataset_dirs_info),
490
- "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
491
- "ss_tag_frequency": json.dumps(dataset.tag_frequency),
492
- "ss_bucket_info": json.dumps(dataset.bucket_info),
493
- })
494
-
495
- # add extra args
496
- if args.network_args:
497
- metadata["ss_network_args"] = json.dumps(net_kwargs)
498
  # for key, value in net_kwargs.items():
499
  # metadata["ss_arg_" + key] = value
500
 
501
- # model name and hash
502
  if args.pretrained_model_name_or_path is not None:
503
  sd_model_name = args.pretrained_model_name_or_path
504
  if os.path.exists(sd_model_name):
@@ -517,13 +583,6 @@ def train(args):
517
 
518
  metadata = {k: str(v) for k, v in metadata.items()}
519
 
520
- # make minimum metadata for filtering
521
- minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
522
- minimum_metadata = {}
523
- for key in minimum_keys:
524
- if key in metadata:
525
- minimum_metadata[key] = metadata[key]
526
-
527
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
528
  global_step = 0
529
 
@@ -536,9 +595,8 @@ def train(args):
536
  loss_list = []
537
  loss_total = 0.0
538
  for epoch in range(num_train_epochs):
539
- if is_main_process:
540
- print(f"epoch {epoch+1}/{num_train_epochs}")
541
- train_dataset_group.set_current_epoch(epoch + 1)
542
 
543
  metadata["ss_epoch"] = str(epoch+1)
544
 
@@ -575,7 +633,7 @@ def train(args):
575
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
576
 
577
  # Predict the noise residual
578
- with accelerator.autocast():
579
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
580
 
581
  if args.v_parameterization:
@@ -593,13 +651,12 @@ def train(args):
593
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
594
 
595
  accelerator.backward(loss)
596
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
597
  params_to_clip = network.get_trainable_params()
598
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
599
 
600
  optimizer.step()
601
- if accelerator.sync_gradients:
602
- lr_scheduler.step()
603
  optimizer.zero_grad(set_to_none=True)
604
 
605
  # Checks if the accelerator has performed an optimization step behind the scenes
@@ -607,8 +664,6 @@ def train(args):
607
  progress_bar.update(1)
608
  global_step += 1
609
 
610
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
611
-
612
  current_loss = loss.detach().item()
613
  if epoch == 0:
614
  loss_list.append(current_loss)
@@ -621,7 +676,7 @@ def train(args):
621
  progress_bar.set_postfix(**logs)
622
 
623
  if args.logging_dir is not None:
624
- logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
625
  accelerator.log(logs, step=global_step)
626
 
627
  if global_step >= args.max_train_steps:
@@ -639,9 +694,8 @@ def train(args):
639
  def save_func():
640
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
641
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
642
- metadata["ss_training_finished_at"] = str(time.time())
643
  print(f"saving checkpoint: {ckpt_file}")
644
- unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
645
 
646
  def remove_old_func(old_epoch_no):
647
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -650,18 +704,15 @@ def train(args):
650
  print(f"removing old checkpoint: {old_ckpt_file}")
651
  os.remove(old_ckpt_file)
652
 
653
- if is_main_process:
654
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
655
- if saving and args.save_state:
656
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
657
-
658
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
659
 
660
  # end of epoch
661
 
662
  metadata["ss_epoch"] = str(num_train_epochs)
663
- metadata["ss_training_finished_at"] = str(time.time())
664
 
 
665
  if is_main_process:
666
  network = unwrap_model(network)
667
 
@@ -680,7 +731,7 @@ def train(args):
680
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
681
 
682
  print(f"save trained model to {ckpt_file}")
683
- network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
684
  print("model saved.")
685
 
686
 
@@ -690,8 +741,6 @@ if __name__ == '__main__':
690
  train_util.add_sd_models_arguments(parser)
691
  train_util.add_dataset_arguments(parser, True, True, True)
692
  train_util.add_training_arguments(parser, True)
693
- train_util.add_optimizer_arguments(parser)
694
- config_util.add_config_arguments(parser)
695
 
696
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
697
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -699,6 +748,10 @@ if __name__ == '__main__':
699
 
700
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
701
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
702
 
703
  parser.add_argument("--network_weights", type=str, default=None,
704
  help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -718,30 +771,27 @@ if __name__ == '__main__':
718
  #Optimizer変更関連のオプション追加
719
  append_module.add_append_arguments(parser)
720
  args = append_module.get_config(parser)
721
- if not args.not_output_config:
722
- #argsを保存する
723
- import yaml
724
- import datetime
725
- _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
726
- if args.output_name==None:
727
- config_name = f"train_network_config_{_t}.yaml"
728
- else:
729
- config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
730
- print(f"{config_name} に設定を書き出し中...")
731
- with open(config_name, mode="w") as f:
732
- yaml.dump(args.__dict__, f, indent=4)
733
 
734
  if args.resolution==args.min_resolution:
735
  args.min_resolution=None
736
 
737
  train(args)
738
- print("done!")
739
 
 
 
740
 
741
  '''
742
  optimizer設定メモ
743
- torch_optimizer.AdaBelief
744
- adastand.Adastand
745
  (optimizer_argから設定できるように変更するためのメモ)
746
 
747
  AdamWのweight_decay初期値は1e-2
@@ -771,7 +821,6 @@ Adafactor
771
  transformerベースのT5学習において最強とかいう噂のoptimizer
772
  huggingfaceのサンプルパラ
773
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
774
- epsの二つ目の値1e-3が学習率に影響大きい
775
 
776
  AggMo
777
 
 
1
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
+ from torch.optim import Optimizer
3
  from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
 
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
+ import diffusers
19
  from diffusers import DDPMScheduler
20
+ print("**********************************")
21
+ #先に
22
+ #pip install torch_optimizer
23
+ #が必要
24
+ try:
25
+ import torch_optimizer as optim
26
+ except:
27
+ print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
28
+ try:
29
+ import adastand
30
+ except:
31
+ print("※Adastandが使えません")
32
+
33
+ from transformers.optimization import Adafactor, AdafactorSchedule
34
+ print("**********************************")
35
  ##### バケット拡張のためのモジュール
36
  import append_module
37
  ######
38
  import library.train_util as train_util
39
+ from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
40
 
41
 
42
  def collate_fn(examples):
43
  return examples[0]
44
 
45
 
46
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
 
47
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
48
+
49
+ if args.network_train_unet_only:
50
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
51
+ elif args.network_train_text_encoder_only:
52
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
 
 
 
53
  else:
54
  last_lrs = lr_scheduler.get_last_lr()
55
+ if len(last_lrs) == 2:
56
+ logs["lr/textencoder"] = float(last_lrs[0])
57
+ logs["lr/unet"] = float(last_lrs[-1]) # may be same to textencoder
58
+ else:
59
+ if len(last_lrs) == 4:
60
+ logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
61
+ elif len(last_lrs) == 8:
62
+ logs_names = ["textencoder", "unet_midblock"]
63
+ for i in range(3):
64
+ logs_names.append(f"unet_down_blocks_{i}")
65
+ logs_names.append(f"unet_up_blocks_{i+1}")
66
+ else:
67
+ logs_names = []
68
+ for i in range(12):
69
+ logs_names.append(f"text_model_encoder_layers_{i}_")
70
+ logs_names.append("unet_midblock")
71
+ for i in range(3):
72
+ logs_names.append(f"unet_down_blocks_{i}")
73
+ logs_names.append(f"unet_up_blocks_{i+1}")
74
+
75
+ for last_lr, logs_name in zip(last_lrs, logs_names):
76
+ logs[f"lr/{logs_name}"] = float(last_lr)
77
 
78
  return logs
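# Aside (assumed values, for illustration only): with separate learning rates for the text
# encoder and U-Net, the dict returned above would look like
#   {"loss/current": 0.1234, "loss/average": 0.1170,
#    "lr/textencoder": 1e-05, "lr/unet": 1e-04}
# and with split LoRA networks one "lr/<block name>" entry is emitted per parameter group.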
79
 
80
 
81
+ # Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
82
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
83
+ # Which is a newer release of diffusers than currently packaged with sd-scripts
84
+ # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
85
+
86
+
87
+ def get_scheduler_fix(
88
+ name: Union[str, SchedulerType],
89
+ optimizer: Optimizer,
90
+ num_warmup_steps: Optional[int] = None,
91
+ num_training_steps: Optional[int] = None,
92
+ num_cycles: float = 1.,
93
+ power: float = 1.0,
94
+ ):
95
+ """
96
+ Unified API to get any scheduler from its name.
97
+ Args:
98
+ name (`str` or `SchedulerType`):
99
+ The name of the scheduler to use.
100
+ optimizer (`torch.optim.Optimizer`):
101
+ The optimizer that will be used during training.
102
+ num_warmup_steps (`int`, *optional*):
103
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
104
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
105
+ num_training_steps (`int``, *optional*):
106
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
107
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
108
+ num_cycles (`int`, *optional*):
109
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
110
+ power (`float`, *optional*, defaults to 1.0):
111
+ Power factor. See `POLYNOMIAL` scheduler
112
+ last_epoch (`int`, *optional*, defaults to -1):
113
+ The index of the last epoch when resuming training.
114
+ """
115
+ name = SchedulerType(name)
116
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
117
+ if name == SchedulerType.CONSTANT:
118
+ return schedule_func(optimizer)
119
+
120
+ # All other schedulers require `num_warmup_steps`
121
+ if num_warmup_steps is None:
122
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
123
+
124
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
125
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
126
+
127
+ # All other schedulers require `num_training_steps`
128
+ if num_training_steps is None:
129
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
130
+
131
+ if name == SchedulerType.COSINE:
132
+ print(f"{name} num_cycles: {num_cycles}")
133
+ return schedule_func(
134
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
135
+ )
136
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
137
+ print(f"{name} num_cycles: {int(num_cycles)}")
138
+ return schedule_func(
139
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
140
+ )
141
+
142
+ if name == SchedulerType.POLYNOMIAL:
143
+ return schedule_func(
144
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
145
+ )
146
+
147
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
148
+
149
+
150
  def train(args):
151
  session_id = random.randint(0, 2**32)
152
  training_started_at = time.time()
 
155
 
156
  cache_latents = args.cache_latents
157
  use_dreambooth_method = args.in_json is None
 
158
 
159
  if args.seed is not None:
160
  set_seed(args.seed)
 
162
  tokenizer = train_util.load_tokenizer(args)
163
 
164
  # データセットを準備する
165
+ if use_dreambooth_method:
166
+ if args.min_resolution:
167
+ args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
168
+ if len(args.min_resolution) == 1:
169
+ args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
170
+
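# For illustration (hypothetical values): "--min_resolution 320" is parsed above into
# (320, 320), and "--min_resolution 320,256" into (320, 256); the tuple is then passed to
# the extended DreamBoothDataset below together with area_step.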
171
+ print("Use DreamBooth method.")
172
+ train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
173
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
174
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
175
+ args.bucket_reso_steps, args.bucket_no_upscale,
176
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
177
+ args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
 
 
 
 
178
  else:
179
+ print("Train with captions.")
180
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
181
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
182
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
183
+ args.bucket_reso_steps, args.bucket_no_upscale,
184
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
185
+ args.dataset_repeats, args.debug_dataset)
186
+
187
+ # 学習データのdropout率を設定する
188
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
189
+
190
+ train_dataset.make_buckets()
191
 
192
  if args.debug_dataset:
193
+ train_util.debug_dataset(train_dataset)
194
  return
195
+ if len(train_dataset) == 0:
196
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
197
  return
198
 
 
 
 
 
199
  # acceleratorを準備する
200
  print("prepare accelerator")
201
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
202
 
203
  # mixed precisionに対応した型を用意しておき適宜castする
204
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
205
 
206
  # モデルを読み込む
207
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
208
+ # unnecessary, but work on low-ram device
209
+ text_encoder.to("cuda")
210
+ unet.to("cuda")
 
 
 
211
  # モデルに xformers とか memory efficient attention を組み込む
212
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
213
 
 
217
  vae.requires_grad_(False)
218
  vae.eval()
219
  with torch.no_grad():
220
+ train_dataset.cache_latents(vae)
221
  vae.to("cpu")
222
  if torch.cuda.is_available():
223
  torch.cuda.empty_cache()
224
  gc.collect()
225
 
226
  # prepare network
 
 
227
  print("import network module:", args.network_module)
228
  network_module = importlib.import_module(args.network_module)
229
 
 
253
 
254
  # 学習に必要なクラスを準備する
255
  print("prepare optimizer, data loader etc.")
256
+ try:
257
+ print(f"torch_optimzier version is {optim.__version__}")
258
+ not_torch_optimizer_flag = False
259
+ except:
260
+ not_torch_optimizer_flag = True
261
+ try:
262
+ print(f"adastand version is {adastand.__version__()}")
263
+ not_adasatand_optimzier_flag = False
264
+ except:
265
+ not_adasatand_optimzier_flag = True
266
+
267
+ # 8-bit Adamを使う
268
+ if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
269
+ not_torch_optimizer_flag = False
270
+ if args.optimizer=="Adafactor":
271
+ not_adasatand_optimzier_flag = False
272
+ if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
273
+ print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
274
+ args.optimizer="AdamW"
275
+ if args.use_8bit_adam:
276
+ if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
277
+ print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
278
+ args.use_8bit_adam=False
279
+ if args.use_8bit_adam:
280
+ try:
281
+ import bitsandbytes as bnb
282
+ except ImportError:
283
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
284
+ print("use 8-bit Adam optimizer")
285
+ args.training_comment=f"{args.training_comment} use_8bit_adam=True"
286
+ if args.optimizer=="Lamb":
287
+ optimizer_class = bnb.optim.LAMB8bit
288
+ else:
289
+ args.optimizer="AdamW"
290
+ optimizer_class = bnb.optim.AdamW8bit
291
+ else:
292
+ print(f"use {args.optimizer}")
293
+ if args.optimizer=="RAdam":
294
+ optimizer_class = torch.optim.RAdam
295
+ elif args.optimizer=="AdaBound":
296
+ optimizer_class = optim.AdaBound
297
+ elif args.optimizer=="AdaBelief":
298
+ optimizer_class = optim.AdaBelief
299
+ elif args.optimizer=="AdamP":
300
+ optimizer_class = optim.AdamP
301
+ elif args.optimizer=="Adafactor":
302
+ optimizer_class = Adafactor
303
+ elif args.optimizer=="Adastand":
304
+ optimizer_class = adastand.Adastand
305
+ elif args.optimizer=="Adastand_belief":
306
+ optimizer_class = adastand.Adastand_b
307
+ elif args.optimizer=="AggMo":
308
+ optimizer_class = optim.AggMo
309
+ elif args.optimizer=="Apollo":
310
+ optimizer_class = optim.Apollo
311
+ elif args.optimizer=="Lamb":
312
+ optimizer_class = optim.Lamb
313
+ elif args.optimizer=="Ranger":
314
+ optimizer_class = optim.Ranger
315
+ elif args.optimizer=="RangerVA":
316
+ optimizer_class = optim.RangerVA
317
+ elif args.optimizer=="Yogi":
318
+ optimizer_class = optim.Yogi
319
+ elif args.optimizer=="Shampoo":
320
+ optimizer_class = optim.Shampoo
321
+ elif args.optimizer=="NovoGrad":
322
+ optimizer_class = optim.NovoGrad
323
+ elif args.optimizer=="QHAdam":
324
+ optimizer_class = optim.QHAdam
325
+ elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
326
+ optimizer_class = optim.DiffGrad
327
+ elif args.optimizer=="MADGRAD":
328
+ optimizer_class = optim.MADGRAD
329
+ else:
330
+ optimizer_class = torch.optim.AdamW
331
+
332
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
333
+ #optimizerデフォ設定
334
+ if args.optimizer_arg==None:
335
+ if args.optimizer=="AdaBelief":
336
+ args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
337
+ elif args.optimizer=="DiffGrad":
338
+ args.optimizer_arg = ["eps=1e-16"]
339
+ optimizer_arg = {}
340
+ lookahed_arg = {"k": 5, "alpha": 0.5}
341
+ adafactor_scheduler_arg = {"initial_lr": 0.}
342
+ int_args = ["k","n_sma_threshold","warmup"]
343
+ str_args = ["transformer","grad_transformer"]
344
+ if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
345
+ for _opt_arg in args.optimizer_arg:
346
+ key, value = _opt_arg.split("=")
347
+ if value=="True" or value=="False":
348
+ optimizer_arg[key]=bool((value=="True"))
349
+ elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
350
+ _value = value.split(",")
351
+ optimizer_arg[key] = (float(_value[0]),float(_value[1]))
352
+ del _value
353
+ elif key in int_args:
354
+ if "Lookahead" in args.optimizer:
355
+ lookahed_arg[key] = int(value)
356
+ else:
357
+ optimizer_arg[key] = int(value)
358
+ elif key in str_args:
359
+ optimizer_arg[key] = value
360
+ else:
361
+ if key=="alpha" and "Lookahead" in args.optimizer:
362
+ lookahed_arg[key] = int(value)
363
+ elif key=="initial_lr" and args.optimizer == "Adafactor":
364
+ adafactor_scheduler_arg[key] = float(value)
365
+ else:
366
+ optimizer_arg[key] = float(value)
367
+ del _opt_arg
368
+ AdafactorScheduler_Flag = False
369
+ list_of_init_lr = []
370
+ if args.optimizer=="Adafactor":
371
+ if not "relative_step" in optimizer_arg:
372
+ optimizer_arg["relative_step"] = True
373
+ if "warmup_init" in optimizer_arg:
374
+ if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
375
+ print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
376
+ optimizer_arg["relative_step"] = True
377
+ if optimizer_arg["relative_step"] == True:
378
+ AdafactorScheduler_Flag = True
379
+ list_of_init_lr = [0.,0.]
380
+ if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
381
+ if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
382
+ #if not "initial_lr" in adafactor_scheduler_arg:
383
+ # adafactor_scheduler_arg = args.learning_rate
384
+ args.learning_rate = None
385
+ args.text_encoder_lr = None
386
+ args.unet_lr = None
387
+ print(f"optimizer arg: {optimizer_arg}")
388
+ print("=-----------------------------------=")
389
+ if not AdafactorScheduler_Flag: args.split_lora_networks = False
390
  if args.split_lora_networks:
 
391
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
392
+ append_module.replace_prepare_optimizer_params(network)
393
+ trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
394
  else:
395
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
396
+ _list_of_init_lr = []
397
+ print(f"trainable_params_len: {len(trainable_params)}")
398
+ if len(_list_of_init_lr)>0:
399
+ list_of_init_lr = _list_of_init_lr
400
+ print(f"split loras network is {len(list_of_init_lr)}")
401
+ if len(list_of_init_lr) > 0:
402
+ adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
403
+
404
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
405
+
406
+ if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
407
+ optimizer = optim.Lookahead(optimizer, **lookahed_arg)
408
+ print(f"lookahed_arg: {lookahed_arg}")
409
+
 
 
 
410
  # dataloaderを準備する
411
  # DataLoaderのプロセス数:0はメインプロセスになる
412
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
413
  train_dataloader = torch.utils.data.DataLoader(
414
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
415
 
416
  # 学習ステップ数を計算する
417
  if args.max_train_epochs is not None:
418
+ args.max_train_steps = args.max_train_epochs * len(train_dataloader)
419
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
420
 
421
  # lr schedulerを用意する
422
+ # lr_scheduler = diffusers.optimization.get_scheduler(
423
+ if AdafactorScheduler_Flag:
424
+ print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
425
+ lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
426
+ print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
427
+ del list_of_init_lr
428
  else:
429
+ lr_scheduler = get_scheduler_fix(
430
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
431
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
432
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
433
+
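# Background sketch for the branch above (standalone example with hypothetical parameters):
# with relative_step=True transformers' Adafactor computes its own lr, so it is paired with
# AdafactorSchedule; with relative_step=False an explicit lr and a normal scheduler are used.
#
#   import torch
#   from transformers.optimization import Adafactor, AdafactorSchedule
#
#   _p = [torch.nn.Parameter(torch.zeros(1))]
#   opt_auto   = Adafactor(_p, lr=None, relative_step=True, warmup_init=True)
#   sched_auto = AdafactorSchedule(opt_auto)          # reports the internally computed lr
#   opt_fixed  = Adafactor(_p, lr=1e-3, relative_step=False, scale_parameter=False)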
434
  #追加機能の設定をコメントに追記して残す
435
+ args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
436
+ if AdafactorScheduler_Flag:
437
+ args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
 
438
  if args.min_resolution:
439
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
440
 
 
504
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
505
 
506
  # 学習する
 
507
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
508
+ print("running training / 学習開始")
509
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
510
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
511
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
512
+ print(f" num epochs / epoch数: {num_train_epochs}")
513
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
514
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
515
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
516
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
517
+
 
 
 
518
  metadata = {
519
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
520
  "ss_training_started_at": training_started_at, # unix timestamp
 
522
  "ss_learning_rate": args.learning_rate,
523
  "ss_text_encoder_lr": args.text_encoder_lr,
524
  "ss_unet_lr": args.unet_lr,
525
+ "ss_num_train_images": train_dataset.num_train_images, # includes repeating
526
+ "ss_num_reg_images": train_dataset.num_reg_images,
527
  "ss_num_batches_per_epoch": len(train_dataloader),
528
  "ss_num_epochs": num_train_epochs,
529
+ "ss_batch_size_per_device": args.train_batch_size,
530
+ "ss_total_batch_size": total_batch_size,
531
  "ss_gradient_checkpointing": args.gradient_checkpointing,
532
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
533
  "ss_max_train_steps": args.max_train_steps,
 
539
  "ss_mixed_precision": args.mixed_precision,
540
  "ss_full_fp16": bool(args.full_fp16),
541
  "ss_v2": bool(args.v2),
542
+ "ss_resolution": args.resolution,
543
  "ss_clip_skip": args.clip_skip,
544
  "ss_max_token_length": args.max_token_length,
545
+ "ss_color_aug": bool(args.color_aug),
546
+ "ss_flip_aug": bool(args.flip_aug),
547
+ "ss_random_crop": bool(args.random_crop),
548
+ "ss_shuffle_caption": bool(args.shuffle_caption),
549
  "ss_cache_latents": bool(args.cache_latents),
550
+ "ss_enable_bucket": bool(train_dataset.enable_bucket),
551
+ "ss_min_bucket_reso": train_dataset.min_bucket_reso,
552
+ "ss_max_bucket_reso": train_dataset.max_bucket_reso,
553
  "ss_seed": args.seed,
554
+ "ss_keep_tokens": args.keep_tokens,
555
  "ss_noise_offset": args.noise_offset,
556
+ "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
557
+ "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
558
+ "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
559
+ "ss_bucket_info": json.dumps(train_dataset.bucket_info),
560
  "ss_training_comment": args.training_comment, # will not be updated after training
561
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
 
 
 
 
 
 
 
562
  }
563
 
564
+ # uncomment if another network is added
 
 
 
 
565
  # for key, value in net_kwargs.items():
566
  # metadata["ss_arg_" + key] = value
567
 
 
568
  if args.pretrained_model_name_or_path is not None:
569
  sd_model_name = args.pretrained_model_name_or_path
570
  if os.path.exists(sd_model_name):
 
583
 
584
  metadata = {k: str(v) for k, v in metadata.items()}
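# Aside (sketch under assumptions, not part of this script): safetensors metadata must be
# Dict[str, str], which is why every value is stringified above before being handed to
# save_weights(). A minimal standalone example of writing such metadata:
#
#   import torch, safetensors.torch
#   safetensors.torch.save_file({"w": torch.zeros(1)}, "example.safetensors",
#                               metadata={"ss_seed": "42", "ss_clip_skip": "2"})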
585
 
 
 
586
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
587
  global_step = 0
588
 
 
595
  loss_list = []
596
  loss_total = 0.0
597
  for epoch in range(num_train_epochs):
598
+ print(f"epoch {epoch+1}/{num_train_epochs}")
599
+ train_dataset.set_current_epoch(epoch + 1)
 
600
 
601
  metadata["ss_epoch"] = str(epoch+1)
602
 
 
633
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
634
 
635
  # Predict the noise residual
636
+ with autocast():
637
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
638
 
639
  if args.v_parameterization:
 
651
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
652
 
653
  accelerator.backward(loss)
654
+ if accelerator.sync_gradients:
655
  params_to_clip = network.get_trainable_params()
656
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
657
 
658
  optimizer.step()
659
+ lr_scheduler.step()
 
660
  optimizer.zero_grad(set_to_none=True)
661
 
662
  # Checks if the accelerator has performed an optimization step behind the scenes
 
664
  progress_bar.update(1)
665
  global_step += 1
666
 
 
 
667
  current_loss = loss.detach().item()
668
  if epoch == 0:
669
  loss_list.append(current_loss)
 
676
  progress_bar.set_postfix(**logs)
677
 
678
  if args.logging_dir is not None:
679
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
680
  accelerator.log(logs, step=global_step)
681
 
682
  if global_step >= args.max_train_steps:
 
694
  def save_func():
695
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
696
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
697
  print(f"saving checkpoint: {ckpt_file}")
698
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
699
 
700
  def remove_old_func(old_epoch_no):
701
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
704
  print(f"removing old checkpoint: {old_ckpt_file}")
705
  os.remove(old_ckpt_file)
706
 
707
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
708
+ if saving and args.save_state:
709
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
710
 
711
  # end of epoch
712
 
713
  metadata["ss_epoch"] = str(num_train_epochs)
 
714
 
715
+ is_main_process = accelerator.is_main_process
716
  if is_main_process:
717
  network = unwrap_model(network)
718
 
 
731
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
732
 
733
  print(f"save trained model to {ckpt_file}")
734
+ network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
735
  print("model saved.")
736
 
737
 
 
741
  train_util.add_sd_models_arguments(parser)
742
  train_util.add_dataset_arguments(parser, True, True, True)
743
  train_util.add_training_arguments(parser, True)
 
 
744
 
745
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
746
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
748
 
749
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
750
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
751
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
752
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
753
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
754
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
755
 
756
  parser.add_argument("--network_weights", type=str, default=None,
757
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
771
  #Optimizer変更関連のオプション追加
772
  append_module.add_append_arguments(parser)
773
  args = append_module.get_config(parser)
 
 
 
 
 
 
 
 
 
 
 
 
774
 
775
  if args.resolution==args.min_resolution:
776
  args.min_resolution=None
777
 
778
  train(args)
 
779
 
780
+ #学習が終わったら現在のargsを保存する
781
+ # import yaml
782
+ # import datetime
783
+ # _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
784
+ # if args.output_name==None:
785
+ # config_name = f"train_network_config_{_t}.yaml"
786
+ # else:
787
+ # config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
788
+ # print(f"{config_name} に設定を書き出し中...")
789
+ # with open(config_name, mode="w") as f:
790
+ # yaml.dump(args.__dict__, f, indent=4)
791
+ # print("done!")
792
 
793
  '''
794
  optimizer設定メモ
 
 
795
  (optimizer_argから設定できるように変更するためのメモ)
796
 
797
  AdamWのweight_decay初期値は1e-2
 
821
  transformerベースのT5学習において最強とかいう噂のoptimizer
822
  huggingfaceのサンプルパラ
823
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
 
824
 
825
  AggMo
826
 
train_textual_inversion.py CHANGED
@@ -11,11 +11,7 @@ import diffusers
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
- import library.config_util as config_util
15
- from library.config_util import (
16
- ConfigSanitizer,
17
- BlueprintGenerator,
18
- )
19
 
20
  imagenet_templates_small = [
21
  "a photo of a {}",
@@ -83,6 +79,7 @@ def train(args):
83
  train_util.prepare_dataset_args(args, True)
84
 
85
  cache_latents = args.cache_latents
 
86
 
87
  if args.seed is not None:
88
  set_seed(args.seed)
@@ -142,35 +139,21 @@ def train(args):
142
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
143
 
144
  # データセットを準備する
145
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
146
- if args.dataset_config is not None:
147
- print(f"Load dataset config from {args.dataset_config}")
148
- user_config = config_util.load_user_config(args.dataset_config)
149
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
150
- if any(getattr(args, attr) is not None for attr in ignored):
151
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
152
  else:
153
- use_dreambooth_method = args.in_json is None
154
- if use_dreambooth_method:
155
- print("Use DreamBooth method.")
156
- user_config = {
157
- "datasets": [{
158
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
159
- }]
160
- }
161
- else:
162
- print("Train with captions.")
163
- user_config = {
164
- "datasets": [{
165
- "subsets": [{
166
- "image_dir": args.train_data_dir,
167
- "metadata_file": args.in_json,
168
- }]
169
- }]
170
- }
171
-
172
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
173
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
174
 
175
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
176
  if use_template:
@@ -180,25 +163,20 @@ def train(args):
180
  captions = []
181
  for tmpl in templates:
182
  captions.append(tmpl.format(replace_to))
183
- train_dataset_group.add_replacement("", captions)
184
- else:
185
- if args.num_vectors_per_token > 1:
186
- replace_to = " ".join(token_strings)
187
- train_dataset_group.add_replacement(args.token_string, replace_to)
188
- prompt_replacement = (args.token_string, replace_to)
189
- else:
190
- prompt_replacement = None
191
 
192
  if args.debug_dataset:
193
- train_util.debug_dataset(train_dataset_group, show_input_ids=True)
194
  return
195
- if len(train_dataset_group) == 0:
196
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
197
  return
198
 
199
- if cache_latents:
200
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
201
-
202
  # モデルに xformers とか memory efficient attention を組み込む
203
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
204
 
@@ -208,7 +186,7 @@ def train(args):
208
  vae.requires_grad_(False)
209
  vae.eval()
210
  with torch.no_grad():
211
- train_dataset_group.cache_latents(vae)
212
  vae.to("cpu")
213
  if torch.cuda.is_available():
214
  torch.cuda.empty_cache()
@@ -220,14 +198,35 @@ def train(args):
220
 
221
  # 学習に必要なクラスを準備する
222
  print("prepare optimizer, data loader etc.")
223
  trainable_params = text_encoder.get_input_embeddings().parameters()
224
- _, _, optimizer = train_util.get_optimizer(args, trainable_params)
225
 
226
  # dataloaderを準備する
227
  # DataLoaderのプロセス数:0はメインプロセスになる
228
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
229
  train_dataloader = torch.utils.data.DataLoader(
230
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
231
 
232
  # 学習ステップ数を計算する
233
  if args.max_train_epochs is not None:
@@ -235,9 +234,8 @@ def train(args):
235
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
236
 
237
  # lr schedulerを用意する
238
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
239
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
240
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
241
 
242
  # acceleratorがなんかよろしくやってくれるらしい
243
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -285,8 +283,8 @@ def train(args):
285
  # 学習する
286
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
287
  print("running training / 学習開始")
288
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
289
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
290
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
291
  print(f" num epochs / epoch数: {num_train_epochs}")
292
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -305,11 +303,12 @@ def train(args):
305
 
306
  for epoch in range(num_train_epochs):
307
  print(f"epoch {epoch+1}/{num_train_epochs}")
308
- train_dataset_group.set_current_epoch(epoch + 1)
309
 
310
  text_encoder.train()
311
 
312
  loss_total = 0
 
313
  for step, batch in enumerate(train_dataloader):
314
  with accelerator.accumulate(text_encoder):
315
  with torch.no_grad():
@@ -358,9 +357,9 @@ def train(args):
358
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
359
 
360
  accelerator.backward(loss)
361
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
362
  params_to_clip = text_encoder.get_input_embeddings().parameters()
363
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
364
 
365
  optimizer.step()
366
  lr_scheduler.step()
@@ -375,14 +374,9 @@ def train(args):
375
  progress_bar.update(1)
376
  global_step += 1
377
 
378
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
379
- vae, tokenizer, text_encoder, unet, prompt_replacement)
380
-
381
  current_loss = loss.detach().item()
382
  if args.logging_dir is not None:
383
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
384
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
385
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
386
  accelerator.log(logs, step=global_step)
387
 
388
  loss_total += current_loss
@@ -400,6 +394,8 @@ def train(args):
400
  accelerator.wait_for_everyone()
401
 
402
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
403
 
404
  if args.save_every_n_epochs is not None:
405
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -421,9 +417,6 @@ def train(args):
421
  if saving and args.save_state:
422
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
423
 
424
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
425
- vae, tokenizer, text_encoder, unet, prompt_replacement)
426
-
427
  # end of epoch
428
 
429
  is_main_process = accelerator.is_main_process
@@ -498,8 +491,6 @@ if __name__ == '__main__':
498
  train_util.add_sd_models_arguments(parser)
499
  train_util.add_dataset_arguments(parser, True, True, False)
500
  train_util.add_training_arguments(parser, True)
501
- train_util.add_optimizer_arguments(parser)
502
- config_util.add_config_arguments(parser)
503
 
504
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
505
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
 
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
+ from library.train_util import DreamBoothDataset, FineTuningDataset
15
 
16
  imagenet_templates_small = [
17
  "a photo of a {}",
 
79
  train_util.prepare_dataset_args(args, True)
80
 
81
  cache_latents = args.cache_latents
82
+ use_dreambooth_method = args.in_json is None
83
 
84
  if args.seed is not None:
85
  set_seed(args.seed)
 
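The use_dreambooth_method flag added above fixes the dataset backend once, up front: when no --in_json metadata file is supplied the script assumes the DreamBooth folder layout, otherwise it trains with captions from the metadata JSON. A minimal sketch of that check, with a hypothetical SimpleNamespace standing in for the script's parsed arguments:

    from types import SimpleNamespace

    # hypothetical stand-in for the parsed command-line arguments
    args = SimpleNamespace(in_json=None, train_data_dir="./train_images", reg_data_dir=None)

    use_dreambooth_method = args.in_json is None   # same test as in the diff
    if use_dreambooth_method:
        print("Use DreamBooth method.")            # folder-based DreamBoothDataset
    else:
        print("Train with captions.")              # metadata-JSON FineTuningDataset
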
139
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
140
 
141
  # データセットを準備する
142
+ if use_dreambooth_method:
143
+ print("Use DreamBooth method.")
144
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
145
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
146
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
147
+ args.bucket_reso_steps, args.bucket_no_upscale,
148
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
149
  else:
150
+ print("Train with captions.")
151
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
152
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
153
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
154
+ args.bucket_reso_steps, args.bucket_no_upscale,
155
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
156
+ args.dataset_repeats, args.debug_dataset)
157
 
158
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
159
  if use_template:
 
163
  captions = []
164
  for tmpl in templates:
165
  captions.append(tmpl.format(replace_to))
166
+ train_dataset.add_replacement("", captions)
167
+ elif args.num_vectors_per_token > 1:
168
+ replace_to = " ".join(token_strings)
169
+ train_dataset.add_replacement(args.token_string, replace_to)
170
+
171
+ train_dataset.make_buckets()
172
 
173
  if args.debug_dataset:
174
+ train_util.debug_dataset(train_dataset, show_input_ids=True)
175
  return
176
+ if len(train_dataset) == 0:
177
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
178
  return
179
 
180
  # モデルに xformers とか memory efficient attention を組み込む
181
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
182
 
 
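Two things happen in the block above: the dataset is now built directly from the command-line arguments (DreamBoothDataset or FineTuningDataset) instead of through the config-file blueprint, and captions are rewritten so the placeholder token expands into all of the learned vectors, i.e. "token token1 token2 ... tokenN". A standalone sketch of that caption expansion, using plain strings rather than the dataset's add_replacement machinery (token_string and num_vectors_per_token are hypothetical values mirroring the script's options):

    # Standalone sketch of the token expansion used above (hypothetical values).
    token_string = "sks"
    num_vectors_per_token = 3

    # "sks" -> "sks sks1 sks2", matching how the script derives token_strings
    token_strings = [token_string] + [f"{token_string}{i}" for i in range(1, num_vectors_per_token)]
    replace_to = " ".join(token_strings)

    # stand-in for train_dataset.add_replacement(token_string, replace_to)
    caption = "a photo of sks on the beach"
    caption = caption.replace(token_string, replace_to)
    print(caption)   # a photo of sks sks1 sks2 on the beach
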
186
  vae.requires_grad_(False)
187
  vae.eval()
188
  with torch.no_grad():
189
+ train_dataset.cache_latents(vae)
190
  vae.to("cpu")
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
 
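With --cache_latents the VAE is only needed once: every image is encoded to latents ahead of the training loop, after which the VAE is dropped to the CPU and the CUDA cache is emptied, as the lines above do. A minimal sketch of that encode-once pattern; the checkpoint id and tensors are placeholders, not the dataset's actual cache_latents implementation:

    import torch
    from diffusers import AutoencoderKL

    @torch.no_grad()
    def encode_to_latents(vae: AutoencoderKL, images: torch.Tensor) -> torch.Tensor:
        # images: float tensor in [-1, 1], shape (N, 3, H, W); keep only the latents
        return (vae.encode(images).latent_dist.sample() * 0.18215).cpu()

    # usage sketch (hypothetical checkpoint and batch):
    # vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to("cuda")
    # latents = encode_to_latents(vae, batch.to("cuda"))
    # vae.to("cpu")
    # if torch.cuda.is_available():
    #     torch.cuda.empty_cache()
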
198
 
199
  # 学習に必要なクラスを準備する
200
  print("prepare optimizer, data loader etc.")
201
+
202
+ # 8-bit Adamを使う
203
+ if args.use_8bit_adam:
204
+ try:
205
+ import bitsandbytes as bnb
206
+ except ImportError:
207
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
208
+ print("use 8-bit Adam optimizer")
209
+ optimizer_class = bnb.optim.AdamW8bit
210
+ elif args.use_lion_optimizer:
211
+ try:
212
+ import lion_pytorch
213
+ except ImportError:
214
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
215
+ print("use Lion optimizer")
216
+ optimizer_class = lion_pytorch.Lion
217
+ else:
218
+ optimizer_class = torch.optim.AdamW
219
+
220
  trainable_params = text_encoder.get_input_embeddings().parameters()
221
+
222
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
223
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
224
 
225
  # dataloaderを準備する
226
  # DataLoaderのプロセス数:0はメインプロセスになる
227
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
228
  train_dataloader = torch.utils.data.DataLoader(
229
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
230
 
231
  # 学習ステップ数を計算する
232
  if args.max_train_epochs is not None:
 
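The block above replaces the shared train_util.get_optimizer helper with the older explicit selection: bitsandbytes' AdamW8bit if --use_8bit_adam is set, lion_pytorch's Lion if --use_lion_optimizer is set, and plain torch.optim.AdamW otherwise. Note also that the DataLoader is now created with batch_size=1 and shuffle=False, because batching and shuffling are handled inside the bucketed dataset. A condensed sketch of the same fallback, assuming the two optional packages expose the classes used in the diff:

    import torch

    def pick_optimizer_class(use_8bit_adam: bool, use_lion: bool):
        # mirrors the branch above: AdamW8bit > Lion > plain AdamW
        if use_8bit_adam:
            import bitsandbytes as bnb       # raises ImportError if missing
            return bnb.optim.AdamW8bit
        if use_lion:
            import lion_pytorch              # raises ImportError if missing
            return lion_pytorch.Lion
        return torch.optim.AdamW

    # usage sketch with hypothetical flags:
    # optimizer_class = pick_optimizer_class(args.use_8bit_adam, args.use_lion_optimizer)
    # optimizer = optimizer_class(text_encoder.get_input_embeddings().parameters(), lr=args.learning_rate)
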
234
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
235
 
236
  # lr schedulerを用意する
237
+ lr_scheduler = diffusers.optimization.get_scheduler(
238
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
239
 
240
  # acceleratorがなんかよろしくやってくれるらしい
241
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 
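The scheduler now comes straight from diffusers' stock helper rather than train_util.get_scheduler_fix, so the extra --lr_scheduler_num_cycles and --lr_scheduler_power options handled by the removed call no longer take effect. A self-contained sketch of the replacement call; the scheduler name, warm-up and step counts are placeholders for the corresponding args values:

    import torch
    from diffusers.optimization import get_scheduler

    params = [torch.nn.Parameter(torch.zeros(4))]      # placeholder trainable parameters
    optimizer = torch.optim.AdamW(params, lr=1e-4)

    lr_scheduler = get_scheduler(
        "cosine",                                      # args.lr_scheduler
        optimizer,
        num_warmup_steps=100,                          # args.lr_warmup_steps
        num_training_steps=1600,                       # args.max_train_steps * args.gradient_accumulation_steps
    )
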
283
  # 学習する
284
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
285
  print("running training / 学習開始")
286
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
287
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
288
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
289
  print(f" num epochs / epoch数: {num_train_epochs}")
290
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
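The reported total batch size is simply the per-device batch size multiplied by the number of accelerate processes and the gradient-accumulation steps; for example:

    train_batch_size = 4              # per device (args.train_batch_size)
    num_processes = 2                 # e.g. two GPUs under accelerate
    gradient_accumulation_steps = 8   # args.gradient_accumulation_steps

    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
    print(total_batch_size)           # 64 samples contribute to each optimizer step
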
303
 
304
  for epoch in range(num_train_epochs):
305
  print(f"epoch {epoch+1}/{num_train_epochs}")
306
+ train_dataset.set_current_epoch(epoch + 1)
307
 
308
  text_encoder.train()
309
 
310
  loss_total = 0
311
+ bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
312
  for step, batch in enumerate(train_dataloader):
313
  with accelerator.accumulate(text_encoder):
314
  with torch.no_grad():
 
357
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
358
 
359
  accelerator.backward(loss)
360
+ if accelerator.sync_gradients:
361
  params_to_clip = text_encoder.get_input_embeddings().parameters()
362
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
363
 
364
  optimizer.step()
365
  lr_scheduler.step()
 
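Note the behavioural change in the clipping above: gradients are now clipped whenever they are synchronized, with the max norm hard-coded to 1.0, whereas the removed lines only clipped when --max_grad_norm was non-zero and used its value. A self-contained sketch of the guarded form, with stand-ins for the script's objects, in case the flag should be honoured again:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    embedding = torch.nn.Embedding(10, 4)     # stand-in for text_encoder.get_input_embeddings()
    max_grad_norm = 1.0                       # stand-in for args.max_grad_norm

    loss = embedding(torch.tensor([0, 1])).sum()
    accelerator.backward(loss)
    if accelerator.sync_gradients and max_grad_norm != 0.0:   # the guard the removed code applied
        accelerator.clip_grad_norm_(embedding.parameters(), max_grad_norm)
    # optimizer.step() / zero_grad() would follow here in the real loop
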
374
  progress_bar.update(1)
375
  global_step += 1
376
 
377
  current_loss = loss.detach().item()
378
  if args.logging_dir is not None:
379
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
380
  accelerator.log(logs, step=global_step)
381
 
382
  loss_total += current_loss
 
394
  accelerator.wait_for_everyone()
395
 
396
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
397
+ # d = updated_embs - bef_epo_embs
398
+ # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
399
 
400
  if args.save_every_n_epochs is not None:
401
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
 
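The bef_epo_embs snapshot taken at the start of each epoch, together with the commented-out difference against updated_embs above, is a quick way to see how far the learned embedding vectors drift per epoch. A standalone sketch of that measurement with dummy tensors; the token ids and embedding size are hypothetical:

    import torch

    token_ids = [49408, 49409]                  # hypothetical ids of the newly added tokens
    weight = torch.randn(49410, 768)            # stand-in for the text encoder embedding matrix

    bef_epo_embs = weight[token_ids].detach().clone()               # snapshot at epoch start
    weight[token_ids] += 0.01 * torch.randn(len(token_ids), 768)    # pretend one epoch of updates
    updated_embs = weight[token_ids].detach().clone()               # snapshot at epoch end

    d = updated_embs - bef_epo_embs
    print(d.abs().mean().item(), d.abs().max().item())              # per-epoch drift of the learned vectors
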
417
  if saving and args.save_state:
418
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
419
420
  # end of epoch
421
 
422
  is_main_process = accelerator.is_main_process
 
491
  train_util.add_sd_models_arguments(parser)
492
  train_util.add_dataset_arguments(parser, True, True, False)
493
  train_util.add_training_arguments(parser, True)
494
 
495
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
496
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")