abc committed
Commit 26a0909 · 1 Parent(s): 3249d87

Upload 55 files

.gitattributes CHANGED
@@ -1 +1,34 @@
- bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
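These attribute rules route matching files through Git LFS instead of storing them as plain git objects. As a rough illustration only (Python's fnmatch approximates but does not exactly reproduce gitattributes glob semantics, and the pattern list below is a hand-picked subset of the rules above), one can preview which paths such rules would send to LFS:

import fnmatch

# Hand-picked subset of the patterns added above (illustrative, not exhaustive).
LFS_PATTERNS = ["*.7z", "*.bin", "*.ckpt", "*.pt", "*.pth", "*.safetensors", "*.zip", "*tfevents*"]

def routed_to_lfs(path: str) -> bool:
    # fnmatch only approximates git's wildmatch rules (e.g. "saved_model/**/*" differs).
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch.fnmatch(name, pattern) for pattern in LFS_PATTERNS)

print(routed_to_lfs("models/unet.safetensors"))  # True
print(routed_to_lfs("append_module.py"))         # False

For an authoritative answer, git's own `git check-attr filter <path>` reports which filter would actually apply.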
append_module.py CHANGED
@@ -2,19 +2,7 @@ import argparse
  import json
  import shutil
  import time
- from typing import (
-     Dict,
-     List,
-     NamedTuple,
-     Optional,
-     Sequence,
-     Tuple,
-     Union,
- )
- from dataclasses import (
-     asdict,
-     dataclass,
- )
  from accelerate import Accelerator
  from torch.autograd.function import Function
  import glob
@@ -40,7 +28,6 @@ import safetensors.torch
 
  import library.model_util as model_util
  import library.train_util as train_util
- import library.config_util as config_util
 
  #============================================================================================================
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
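The comment above refers to AdafactorSchedule_append, a variant that lets each parameter group start from its own initial_lr. For orientation, a minimal sketch of the stock transformers API that it builds on (dummy parameters and a hypothetical learning rate; this is not the subclass defined in this module):

import torch
from transformers.optimization import Adafactor, AdafactorSchedule

# Two dummy parameter groups standing in for per-block LoRA groups.
p1, p2 = torch.nn.Parameter(torch.zeros(4)), torch.nn.Parameter(torch.zeros(4))
optimizer = Adafactor([{"params": [p1]}, {"params": [p2]}],
                      lr=None, scale_parameter=True, relative_step=True, warmup_init=True)

# The stock schedule takes a single initial_lr for every group; the append version
# in this file applies a per-group value instead.
lr_scheduler = AdafactorSchedule(optimizer, initial_lr=1e-3)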
@@ -128,124 +115,6 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
128
  return area_size_resos_list, area_size_list
129
 
130
  #============================================================================================================
131
- #config_util 内より
132
- #============================================================================================================
133
- @dataclass
134
- class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
135
- min_resolution: Optional[Tuple[int, int]] = None
136
- area_step : int = 2
137
-
138
- class ConfigSanitizer(config_util.ConfigSanitizer):
139
- #@config_util.curry
140
- @staticmethod
141
- def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
142
- config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
143
- return tuple(value)
144
-
145
- #@config_util.curry
146
- @staticmethod
147
- def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
148
- config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
149
- try:
150
- config_util.Schema(klass)(value)
151
- return (value, value)
152
- except:
153
- return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
154
- # datasets schema
155
- DATASET_ASCENDABLE_SCHEMA = {
156
- "batch_size": int,
157
- "bucket_no_upscale": bool,
158
- "bucket_reso_steps": int,
159
- "enable_bucket": bool,
160
- "max_bucket_reso": int,
161
- "min_bucket_reso": int,
162
- "resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
163
- "min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
164
- "area_step": int,
165
- }
166
- def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
167
- super().__init__(support_dreambooth, support_finetuning, support_dropout)
168
- def _check(self):
169
- print(self.db_dataset_schema)
170
-
171
- class BlueprintGenerator(config_util.BlueprintGenerator):
172
- def __init__(self, sanitizer: ConfigSanitizer):
173
- config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
174
- super().__init__(sanitizer)
175
-
176
- def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
177
- datasets: List[Union[DreamBoothDataset, train_util.FineTuningDataset]] = []
178
-
179
- for dataset_blueprint in dataset_group_blueprint.datasets:
180
- if dataset_blueprint.is_dreambooth:
181
- subset_klass = train_util.DreamBoothSubset
182
- dataset_klass = DreamBoothDataset
183
- else:
184
- subset_klass = train_util.FineTuningSubset
185
- dataset_klass = train_util.FineTuningDataset
186
-
187
- subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
188
- dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
189
- datasets.append(dataset)
190
-
191
- # print info
192
- info = ""
193
- for i, dataset in enumerate(datasets):
194
- is_dreambooth = isinstance(dataset, DreamBoothDataset)
195
- info += config_util.dedent(f"""\
196
- [Dataset {i}]
197
- batch_size: {dataset.batch_size}
198
- resolution: {(dataset.width, dataset.height)}
199
- enable_bucket: {dataset.enable_bucket}
200
- """)
201
-
202
- if dataset.enable_bucket:
203
- info += config_util.indent(config_util.dedent(f"""\
204
- min_bucket_reso: {dataset.min_bucket_reso}
205
- max_bucket_reso: {dataset.max_bucket_reso}
206
- bucket_reso_steps: {dataset.bucket_reso_steps}
207
- bucket_no_upscale: {dataset.bucket_no_upscale}
208
- \n"""), " ")
209
- else:
210
- info += "\n"
211
-
212
- for j, subset in enumerate(dataset.subsets):
213
- info += config_util.indent(config_util.dedent(f"""\
214
- [Subset {j} of Dataset {i}]
215
- image_dir: "{subset.image_dir}"
216
- image_count: {subset.img_count}
217
- num_repeats: {subset.num_repeats}
218
- shuffle_caption: {subset.shuffle_caption}
219
- keep_tokens: {subset.keep_tokens}
220
- caption_dropout_rate: {subset.caption_dropout_rate}
221
- caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
222
- caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
223
- color_aug: {subset.color_aug}
224
- flip_aug: {subset.flip_aug}
225
- face_crop_aug_range: {subset.face_crop_aug_range}
226
- random_crop: {subset.random_crop}
227
- """), " ")
228
-
229
- if is_dreambooth:
230
- info += config_util.indent(config_util.dedent(f"""\
231
- is_reg: {subset.is_reg}
232
- class_tokens: {subset.class_tokens}
233
- caption_extension: {subset.caption_extension}
234
- \n"""), " ")
235
- else:
236
- info += config_util.indent(config_util.dedent(f"""\
237
- metadata_file: {subset.metadata_file}
238
- \n"""), " ")
239
-
240
- print(info)
241
-
242
- # make buckets first because it determines the length of dataset
243
- for i, dataset in enumerate(datasets):
244
- print(f"[Dataset {i}]")
245
- dataset.make_buckets()
246
-
247
- return train_util.DatasetGroup(datasets)
248
- #============================================================================================================
249
  #train_util 内より
250
  #============================================================================================================
251
  class BucketManager_append(train_util.BucketManager):
@@ -310,7 +179,7 @@ class BucketManager_append(train_util.BucketManager):
  bucket_size_id_list.append(bucket_size_id + i + 1)
  _min_error = 1000.
  _min_id = bucket_size_id
- for now_size_id in bucket_size_id_list:
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
  ar_error = np.abs(ar_errors).min()
@@ -384,13 +253,13 @@ class BucketManager_append(train_util.BucketManager):
  return reso, resized_size, ar_error
 
  class DreamBoothDataset(train_util.DreamBoothDataset):
- def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
  print("use append DreamBoothDataset")
  self.min_resolution = min_resolution
  self.area_step = area_step
- super().__init__(subsets, batch_size, tokenizer, max_token_length,
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
- prior_loss_weight, debug_dataset)
  def make_buckets(self):
  '''
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
@@ -483,50 +352,40 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
483
  self.shuffle_buckets()
484
  self._length = len(self.buckets_indices)
485
 
486
- import transformers
487
- from torch.optim import Optimizer
488
- from diffusers.optimization import SchedulerType
489
- from typing import Union
490
- def get_scheduler_Adafactor(
491
- name: Union[str, SchedulerType],
492
- optimizer: Optimizer,
493
- scheduler_arg: Dict
494
- ):
495
- if name.startswith("adafactor"):
496
- assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
497
- print(scheduler_arg)
498
- return AdafactorSchedule_append(optimizer, **scheduler_arg)
499
  #============================================================================================================
500
  #networks.lora
501
  #============================================================================================================
502
- #from networks.lora import LoRANetwork
503
- def replace_prepare_optimizer_params(networks, network_module):
504
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
505
-
506
  def enumerate_params(loras, lora_name=None):
507
  params = []
508
  for lora in loras:
509
  if lora_name is not None:
510
- get_param_flag = False
511
- if "attentions" in lora_name or "lora_unet_up_blocks_0_resnets_2":
512
- lora_names = [lora_name]
513
- if "attentions" in lora_name:
514
- lora_names.append(lora_name.replace("attentions", "resnets"))
515
- elif "lora_unet_up_blocks_0_resnets_2" in lora_name:
516
- lora_names.append("lora_unet_up_blocks_0_upsamplers_")
517
- elif "lora_unet_up_blocks_1_attentions_2_" in lora_name:
518
- lora_names.append("lora_unet_up_blocks_1_upsamplers_")
519
- elif "lora_unet_up_blocks_2_attentions_2_" in lora_name:
520
- lora_names.append("lora_unet_up_blocks_2_upsamplers_")
521
-
522
- for _name in lora_names:
523
- if _name in lora.lora_name:
524
- get_param_flag = True
525
- break
526
- else:
527
- if lora_name in lora.lora_name:
528
- get_param_flag = True
529
- if get_param_flag: params.extend(lora.parameters())
530
  else:
531
  params.extend(lora.parameters())
532
  return params
@@ -534,7 +393,6 @@ def replace_prepare_optimizer_params(networks, network_module):
  self.requires_grad_(True)
  all_params = []
  ret_scheduler_lr = []
- used_names = []
 
  if loranames is not None:
  textencoder_names = [None]
@@ -547,181 +405,37 @@ def replace_prepare_optimizer_params(networks, network_module):
547
  if self.text_encoder_loras:
548
  for textencoder_name in textencoder_names:
549
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
550
- used_names.append(textencoder_name)
551
  if text_encoder_lr is not None:
552
  param_data['lr'] = text_encoder_lr
553
- if lr_dic is not None:
554
- if textencoder_name in lr_dic:
555
- param_data['lr'] = lr_dic[textencoder_name]
556
- print(f"{textencoder_name} lr: {param_data['lr']}")
557
-
558
- if block_args_dic is not None:
559
- if "lora_te_" in block_args_dic:
560
- for pname, value in block_args_dic["lora_te_"].items():
561
- param_data[pname] = value
562
- if textencoder_name in block_args_dic:
563
- for pname, value in block_args_dic[textencoder_name].items():
564
- param_data[pname] = value
565
-
566
- if text_encoder_lr is not None:
567
- ret_scheduler_lr.append(text_encoder_lr)
568
- else:
569
- ret_scheduler_lr.append(0.)
570
- if lr_dic is not None:
571
- if textencoder_name in lr_dic:
572
- ret_scheduler_lr[-1] = lr_dic[textencoder_name]
573
  all_params.append(param_data)
574
 
575
  if self.unet_loras:
576
  for unet_name in unet_names:
577
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
578
- if len(param_data["params"])==0: continue
579
- used_names.append(unet_name)
580
  if unet_lr is not None:
581
  param_data['lr'] = unet_lr
582
- if lr_dic is not None:
583
- if unet_name in lr_dic:
584
- param_data['lr'] = lr_dic[unet_name]
585
- print(f"{unet_name} lr: {param_data['lr']}")
586
-
587
- if block_args_dic is not None:
588
- if "lora_unet_" in block_args_dic:
589
- for pname, value in block_args_dic["lora_unet_"].items():
590
- param_data[pname] = value
591
- if unet_name in block_args_dic:
592
- for pname, value in block_args_dic[unet_name].items():
593
- param_data[pname] = value
594
-
595
- if unet_lr is not None:
596
- ret_scheduler_lr.append(unet_lr)
597
- else:
598
- ret_scheduler_lr.append(0.)
599
- if lr_dic is not None:
600
- if unet_name in lr_dic:
601
- ret_scheduler_lr[-1] = lr_dic[unet_name]
602
  all_params.append(param_data)
603
 
604
- return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
605
- try:
606
- network_module.LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
607
- except:
608
- print("cant't replace prepare_optimizer_params")
609
 
610
  #============================================================================================================
611
  #新規追加
612
  #============================================================================================================
613
  def add_append_arguments(parser: argparse.ArgumentParser):
614
  # for train_network_opt.py
615
- #parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
616
- #parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
617
- parser.add_argument("--use_lookahead", action="store_true")
618
- parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
619
  parser.add_argument("--split_lora_networks", action="store_true")
620
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
621
- parser.add_argument("--blocks_lr_setting", type=str, default=None)
622
- parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
623
  parser.add_argument("--min_resolution", type=str, default=None)
624
  parser.add_argument("--area_step", type=int, default=1)
625
  parser.add_argument("--config", type=str, default=None)
626
- parser.add_argument("--not_output_config", action="store_true")
627
-
628
- class MyNetwork_Names:
629
- ex_block_weight_dic = {
630
- "BASE": ["te"],
631
- "IN01": ["down_0_at_0","donw_0_res_0"], "IN02": ["down_0_at_1","down_0_res_1"], "IN03": ["down_0_down"],
632
- "IN04": ["down_1_at_0","donw_1_res_0"], "IN05": ["down_1_at_1","donw_1_res_1"], "IN06": ["down_1_down"],
633
- "IN07": ["down_2_at_0","donw_2_res_0"], "IN08": ["down_2_at_1","donw_2_res_1"], "IN09": ["down_2_down"],
634
- "IN10": ["down_3_res_0"], "IN11": ["down_3_res_1"],
635
- "MID": ["mid"],
636
- "OUT00": ["up_0_res_0"], "OUT01": ["up_0_res_1"], "OUT02": ["up_0_res_2", "up_0_up"],
637
- "OUT03": ["up_1_at_0", "up_1_res_0"], "OUT04": ["up_1_at_1", "up_1_res_1"], "OUT05": ["up_1_at_2", "up_1_res_2", "up_1_up"],
638
- "OUT06": ["up_2_at_0", "up_2_res_0"], "OUT07": ["up_2_at_1", "up_2_res_1"], "OUT08": ["up_2_at_2", "up_2_res_2", "up_2_up"],
639
- "OUT09": ["up_3_at_0", "up_3_res_0"], "OUT10": ["up_3_at_1", "up_3_res_1"], "OUT11": ["up_3_at_2", "up_3_res_2"],
640
- }
641
-
642
- blocks_name_dic = { "te": "lora_te_",
643
- "unet": "lora_unet_",
644
- "mid": "lora_unet_mid_block_",
645
- "down": "lora_unet_down_blocks_",
646
- "up": "lora_unet_up_blocks_"}
647
- for i in range(12):
648
- blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
649
- for i in range(3):
650
- blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
651
- blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
652
- for i in range(4):
653
- for j in range(2):
654
- if i<=2: blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
655
- blocks_name_dic[f"down_{i}_res_{j}"] = f"lora_unet_down_blocks_{i}_resnets_{j}"
656
- for j in range(3):
657
- if i>=1: blocks_name_dic[f"up_{i}_at_{j}"] = f"lora_unet_up_blocks_{i}_attentions_{j}_"
658
- blocks_name_dic[f"up_{i}_res_{j}"] = f"lora_unet_up_blocks_{i}_resnets_{j}"
659
- if i<=2:
660
- blocks_name_dic[f"down_{i}_down"] = f"lora_unet_down_blocks_{i}_downsamplers_"
661
- blocks_name_dic[f"up_{i}_up"] = f"lora_unet_up_blocks_{i}_upsamplers_"
662
-
663
- def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
664
- ex_block_weight_dic = MyNetwork_Names.ex_block_weight_dic
665
- blocks_name_dic = MyNetwork_Names.blocks_name_dic
666
-
667
- lr_dic = {}
668
- if lr_setting_str==None or lr_setting_str=="":
669
- pass
670
- else:
671
- lr_settings = lr_setting_str.replace(" ", "").split(",")
672
- for lr_setting in lr_settings:
673
- key, value = lr_setting.split("=")
674
- if key in ex_block_weight_dic:
675
- keys = ex_block_weight_dic[key]
676
- else:
677
- keys = [key]
678
- for key in keys:
679
- if key in blocks_name_dic:
680
- new_key = blocks_name_dic[key]
681
- lr_dic[new_key] = float(value)
682
- if len(lr_dic)==0:
683
- lr_dic = None
684
-
685
- args_dic = {}
686
- if (block_optim_args is None):
687
- block_optim_args = []
688
- if (len(block_optim_args)>0):
689
- for my_arg in block_optim_args:
690
- my_arg = my_arg.replace(" ", "")
691
- splits = my_arg.split(":")
692
- b_name = splits[0]
693
-
694
- key, _value = splits[1].split("=")
695
- value_type = float
696
- if len(splits)==3:
697
- if _value=="str":
698
- value_type = str
699
- elif _value=="int":
700
- value_type = int
701
- _value = splits[2]
702
- if _value=="true" or _value=="false":
703
- value_type = bool
704
- if "," in _value:
705
- _value = _value.split(",")
706
- for i in range(len(_value)):
707
- _value[i] = value_type(_value[i])
708
- value=tuple(_value)
709
- else:
710
- value = value_type(_value)
711
-
712
- if b_name in ex_block_weight_dic:
713
- b_names = ex_block_weight_dic[b_name]
714
- else:
715
- b_names = [b_name]
716
- for b_name in b_names:
717
- new_b_name = blocks_name_dic[b_name]
718
- if not new_b_name in args_dic:
719
- args_dic[new_b_name] = {}
720
- args_dic[new_b_name][key] = value
721
-
722
- if len(args_dic)==0:
723
- args_dic = None
724
- return lr_dic, args_dic
725
 
726
  def create_split_names(split_flag, split_level):
727
  split_names = None
@@ -732,28 +446,14 @@ def create_split_names(split_flag, split_level):
732
  if split_level==1:
733
  unet_names.append(f"lora_unet_down_blocks_")
734
  unet_names.append(f"lora_unet_up_blocks_")
735
- elif split_level==2 or split_level==0 or split_level==4:
736
- if split_level>=2:
737
  text_encoder_names = []
738
  for i in range(12):
739
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
740
-
741
- if split_level<=2:
742
- for i in range(3):
743
- unet_names.append(f"lora_unet_down_blocks_{i}")
744
- unet_names.append(f"lora_unet_up_blocks_{i+1}")
745
-
746
- if split_level>=3:
747
- for i in range(4):
748
- for j in range(2):
749
- if i<=2: unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
750
- if i== 3: unet_names.append(f"lora_unet_down_blocks_{i}_resnets_{j}")
751
- for j in range(3):
752
- if i>=1: unet_names.append(f"lora_unet_up_blocks_{i}_attentions_{j}_")
753
- if i==0: unet_names.append(f"lora_unet_up_blocks_{i}_resnets_{j}")
754
- if i<=2:
755
- unet_names.append(f"lora_unet_down_blocks_{i}_downsamplers_")
756
-
757
  split_names["text_encoder"] = text_encoder_names
758
  split_names["unet"] = unet_names
759
  return split_names
@@ -765,7 +465,7 @@ def get_config(parser):
  import datetime
  if os.path.splitext(args.config)[-1] == ".yaml":
  args.config = os.path.splitext(args.config)[0]
- config_path = f"{args.config}.yaml"
  if os.path.exists(config_path):
  print(f"{config_path} から設定を読み込み中...")
  margs, rest = parser.parse_known_args()
@@ -786,41 +486,19 @@ def get_config(parser):
786
  args_type_dic[key] = act.type
787
  #データタイプの確認とargsにkeyの内容を代入していく
788
  for key, v in configs.items():
789
- if v is not None:
790
- if key in args_dic:
791
- if args_dic[key] is not None:
792
- new_type = type(args_dic[key])
793
- if (not type(v) == new_type) and (not new_type==list):
794
- v = new_type(v)
795
- else:
796
  if not type(v) == args_type_dic[key]:
797
  v = args_type_dic[key](v)
798
- args_dic[key] = v
799
  #最後にデフォから指定が変わってるものを変更する
800
  for key, v in change_def_dic.items():
801
  args_dic[key] = v
802
  else:
803
  print(f"{config_path} が見つかりませんでした")
804
  return args
805
-
806
- '''
807
- class GradientReversalFunction(torch.autograd.Function):
808
- @staticmethod
809
- def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
810
- ctx.save_for_backward(scale)
811
- return input_forward
812
- @staticmethod
813
- def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
814
- scale, = ctx.saved_tensors
815
- return scale * -grad_backward, None
816
-
817
- class GradientReversal(torch.nn.Module):
818
- def __init__(self, scale: float):
819
- super(GradientReversal, self).__init__()
820
- self.scale = torch.tensor(scale)
821
- def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
822
- if flag:
823
- return x
824
- else:
825
- return GradientReversalFunction.apply(x, self.scale)
826
- '''
 
2
  import json
3
  import shutil
4
  import time
5
+ from typing import Dict, List, NamedTuple, Tuple
6
  from accelerate import Accelerator
7
  from torch.autograd.function import Function
8
  import glob
 
28
 
29
  import library.model_util as model_util
30
  import library.train_util as train_util
 
31
 
32
  #============================================================================================================
33
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
 
115
  return area_size_resos_list, area_size_list
116
 
117
  #============================================================================================================
118
  #train_util 内より
119
  #============================================================================================================
120
  class BucketManager_append(train_util.BucketManager):
 
179
  bucket_size_id_list.append(bucket_size_id + i + 1)
180
  _min_error = 1000.
181
  _min_id = bucket_size_id
182
+ for now_size_id in bucket_size_id:
183
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
184
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
185
  ar_error = np.abs(ar_errors).min()
 
253
  return reso, resized_size, ar_error
254
 
255
  class DreamBoothDataset(train_util.DreamBoothDataset):
256
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
257
  print("use append DreamBoothDataset")
258
  self.min_resolution = min_resolution
259
  self.area_step = area_step
260
+ super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
261
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
262
+ flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
263
  def make_buckets(self):
264
  '''
265
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
 
352
  self.shuffle_buckets()
353
  self._length = len(self.buckets_indices)
354
 
355
+ class FineTuningDataset(train_util.FineTuningDataset):
356
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
357
+ train_util.glob_images = glob_images
358
+ super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
359
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
360
+ random_crop, dataset_repeats, debug_dataset)
361
+
362
+ def glob_images(directory, base="*", npz_flag=True):
363
+ img_paths = []
364
+ dots = []
365
+ for ext in train_util.IMAGE_EXTENSIONS:
366
+ dots.append(ext)
367
+ if npz_flag:
368
+ dots.append(".npz")
369
+ for ext in dots:
370
+ if base == '*':
371
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
372
+ else:
373
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
374
+ return img_paths
375
+
376
  #============================================================================================================
377
  #networks.lora
378
  #============================================================================================================
379
+ from networks.lora import LoRANetwork
380
+ def replace_prepare_optimizer_params(networks):
381
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
382
+
383
  def enumerate_params(loras, lora_name=None):
384
  params = []
385
  for lora in loras:
386
  if lora_name is not None:
387
+ if lora_name in lora.lora_name:
388
+ params.extend(lora.parameters())
389
  else:
390
  params.extend(lora.parameters())
391
  return params
 
393
  self.requires_grad_(True)
394
  all_params = []
395
  ret_scheduler_lr = []
 
396
 
397
  if loranames is not None:
398
  textencoder_names = [None]
 
405
  if self.text_encoder_loras:
406
  for textencoder_name in textencoder_names:
407
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
 
408
  if text_encoder_lr is not None:
409
  param_data['lr'] = text_encoder_lr
410
+ if scheduler_lr is not None:
411
+ ret_scheduler_lr.append(scheduler_lr[0])
 
412
  all_params.append(param_data)
413
 
414
  if self.unet_loras:
415
  for unet_name in unet_names:
416
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
 
 
417
  if unet_lr is not None:
418
  param_data['lr'] = unet_lr
419
+ if scheduler_lr is not None:
420
+ ret_scheduler_lr.append(scheduler_lr[1])
 
421
  all_params.append(param_data)
422
 
423
+ return all_params, ret_scheduler_lr
424
+
425
+ LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
 
 
426
 
427
  #============================================================================================================
428
  #新規追加
429
  #============================================================================================================
430
  def add_append_arguments(parser: argparse.ArgumentParser):
431
  # for train_network_opt.py
432
+ parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
433
+ parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
 
 
434
  parser.add_argument("--split_lora_networks", action="store_true")
435
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
 
 
436
  parser.add_argument("--min_resolution", type=str, default=None)
437
  parser.add_argument("--area_step", type=int, default=1)
438
  parser.add_argument("--config", type=str, default=None)
439
 
440
  def create_split_names(split_flag, split_level):
441
  split_names = None
 
446
  if split_level==1:
447
  unet_names.append(f"lora_unet_down_blocks_")
448
  unet_names.append(f"lora_unet_up_blocks_")
449
+ elif split_level==2 or split_level==0:
450
+ if split_level==2:
451
  text_encoder_names = []
452
  for i in range(12):
453
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
454
+ for i in range(3):
455
+ unet_names.append(f"lora_unet_down_blocks_{i}")
456
+ unet_names.append(f"lora_unet_up_blocks_{i+1}")
 
457
  split_names["text_encoder"] = text_encoder_names
458
  split_names["unet"] = unet_names
459
  return split_names
 
465
  import datetime
466
  if os.path.splitext(args.config)[-1] == ".yaml":
467
  args.config = os.path.splitext(args.config)[0]
468
+ config_path = f"./{args.config}.yaml"
469
  if os.path.exists(config_path):
470
  print(f"{config_path} から設定を読み込み中...")
471
  margs, rest = parser.parse_known_args()
 
486
  args_type_dic[key] = act.type
487
  #データタイプの確認とargsにkeyの内容を代入していく
488
  for key, v in configs.items():
489
+ if key in args_dic:
490
+ if args_dic[key] is not None:
491
+ new_type = type(args_dic[key])
492
+ if (not type(v) == new_type) and (not new_type==list):
493
+ v = new_type(v)
494
+ else:
495
+ if v is not None:
496
  if not type(v) == args_type_dic[key]:
497
  v = args_type_dic[key](v)
498
+ args_dic[key] = v
499
  #最後にデフォから指定が変わってるものを変更する
500
  for key, v in change_def_dic.items():
501
  args_dic[key] = v
502
  else:
503
  print(f"{config_path} が見つかりませんでした")
504
  return args
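The block above swaps in a new prepare_optimizer_params by assigning the nested function directly onto the LoRANetwork class. A self-contained sketch of that monkey-patching pattern, using a stand-in module instead of networks.lora (the names below are illustrative, not the repo's API):

import torch

class TinyNetwork(torch.nn.Module):  # stand-in for networks.lora.LoRANetwork
    def __init__(self):
        super().__init__()
        self.text_encoder_loras = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
        self.unet_loras = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

def replace_prepare_optimizer_params(network_class):
    # Same shape as the diff: build param groups, collect per-group initial lrs,
    # then attach the function to the class so existing call sites pick it up.
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None):
        all_params, ret_scheduler_lr = [], []
        for loras, lr, idx in ((self.text_encoder_loras, text_encoder_lr, 0),
                               (self.unet_loras, unet_lr, 1)):
            if not loras:
                continue
            group = {"params": [p for m in loras for p in m.parameters()]}
            if lr is not None:
                group["lr"] = lr
            if scheduler_lr is not None:
                ret_scheduler_lr.append(scheduler_lr[idx])
            all_params.append(group)
        return all_params, ret_scheduler_lr
    network_class.prepare_optimizer_params = prepare_optimizer_params

replace_prepare_optimizer_params(TinyNetwork)
params, initial_lrs = TinyNetwork().prepare_optimizer_params(5e-5, 1e-4, scheduler_lr=(5e-5, 1e-4))

Because the attribute is set on the class rather than an instance, every network constructed afterwards sees the replacement, which is what lets the training script keep calling the method under its original name.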
build/lib/library/__init__.py ADDED
File without changes
build/lib/library/model_util.py ADDED
@@ -0,0 +1,1180 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ # DiffUsers版StableDiffusionのモデルパラメータ
12
+ NUM_TRAIN_TIMESTEPS = 1000
13
+ BETA_START = 0.00085
14
+ BETA_END = 0.0120
15
+
16
+ UNET_PARAMS_MODEL_CHANNELS = 320
17
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
18
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
19
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
20
+ UNET_PARAMS_IN_CHANNELS = 4
21
+ UNET_PARAMS_OUT_CHANNELS = 4
22
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
23
+ UNET_PARAMS_CONTEXT_DIM = 768
24
+ UNET_PARAMS_NUM_HEADS = 8
25
+
26
+ VAE_PARAMS_Z_CHANNELS = 4
27
+ VAE_PARAMS_RESOLUTION = 256
28
+ VAE_PARAMS_IN_CHANNELS = 3
29
+ VAE_PARAMS_OUT_CH = 3
30
+ VAE_PARAMS_CH = 128
31
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
32
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
33
+
34
+ # V2
35
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
36
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
37
+
38
+ # Diffusersの設定を読み込むための参照モデル
39
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
40
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
41
+
42
+
43
+ # region StableDiffusion->Diffusersの変換コード
44
+ # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
45
+
46
+
47
+ def shave_segments(path, n_shave_prefix_segments=1):
48
+ """
49
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
50
+ """
51
+ if n_shave_prefix_segments >= 0:
52
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
53
+ else:
54
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
55
+
56
+
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
+
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
+
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
+
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
+
111
+ mapping.append({"old": old_item, "new": new_item})
112
+
113
+ return mapping
114
+
115
+
116
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
+
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
+
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
+
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
+
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
+
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
+
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
+
141
+ mapping.append({"old": old_item, "new": new_item})
142
+
143
+ return mapping
144
+
145
+
146
+ def assign_to_checkpoint(
147
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
+
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
+
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
+
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
+
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
+
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
+
175
+ for path in paths:
176
+ new_path = path["new"]
177
+
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
+
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
+
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
+
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
+
197
+
198
+ def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
+
209
+
210
+ def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
+
218
+
219
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
251
+ for layer_id in range(num_input_blocks)
252
+ }
253
+
254
+ # Retrieves the keys for the middle blocks only
255
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
256
+ middle_blocks = {
257
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
258
+ for layer_id in range(num_middle_blocks)
259
+ }
260
+
261
+ # Retrieves the keys for the output blocks only
262
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
263
+ output_blocks = {
264
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
265
+ for layer_id in range(num_output_blocks)
266
+ }
267
+
268
+ for i in range(1, num_input_blocks):
269
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
270
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
271
+
272
+ resnets = [
273
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
274
+ ]
275
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
276
+
277
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
278
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
279
+ f"input_blocks.{i}.0.op.weight"
280
+ )
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.bias"
283
+ )
284
+
285
+ paths = renew_resnet_paths(resnets)
286
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
287
+ assign_to_checkpoint(
288
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
289
+ )
290
+
291
+ if len(attentions):
292
+ paths = renew_attention_paths(attentions)
293
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
294
+ assign_to_checkpoint(
295
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
296
+ )
297
+
298
+ resnet_0 = middle_blocks[0]
299
+ attentions = middle_blocks[1]
300
+ resnet_1 = middle_blocks[2]
301
+
302
+ resnet_0_paths = renew_resnet_paths(resnet_0)
303
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
304
+
305
+ resnet_1_paths = renew_resnet_paths(resnet_1)
306
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ attentions_paths = renew_attention_paths(attentions)
309
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
310
+ assign_to_checkpoint(
311
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
312
+ )
313
+
314
+ for i in range(num_output_blocks):
315
+ block_id = i // (config["layers_per_block"] + 1)
316
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
317
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
318
+ output_block_list = {}
319
+
320
+ for layer in output_block_layers:
321
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
322
+ if layer_id in output_block_list:
323
+ output_block_list[layer_id].append(layer_name)
324
+ else:
325
+ output_block_list[layer_id] = [layer_name]
326
+
327
+ if len(output_block_list) > 1:
328
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
329
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
330
+
331
+ resnet_0_paths = renew_resnet_paths(resnets)
332
+ paths = renew_resnet_paths(resnets)
333
+
334
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
335
+ assign_to_checkpoint(
336
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
337
+ )
338
+
339
+ # オリジナル:
340
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
341
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
342
+
343
+ # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
344
+ for l in output_block_list.values():
345
+ l.sort()
346
+
347
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
348
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
349
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
350
+ f"output_blocks.{i}.{index}.conv.bias"
351
+ ]
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.weight"
354
+ ]
355
+
356
+ # Clear attentions as they have been attributed above.
357
+ if len(attentions) == 2:
358
+ attentions = []
359
+
360
+ if len(attentions):
361
+ paths = renew_attention_paths(attentions)
362
+ meta_path = {
363
+ "old": f"output_blocks.{i}.1",
364
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
365
+ }
366
+ assign_to_checkpoint(
367
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
368
+ )
369
+ else:
370
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
371
+ for path in resnet_0_paths:
372
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
373
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
374
+
375
+ new_checkpoint[new_path] = unet_state_dict[old_path]
376
+
377
+ # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
378
+ if v2:
379
+ linear_transformer_to_conv(new_checkpoint)
380
+
381
+ return new_checkpoint
382
+
383
+
384
+ def convert_ldm_vae_checkpoint(checkpoint, config):
385
+ # extract state dict for VAE
386
+ vae_state_dict = {}
387
+ vae_key = "first_stage_model."
388
+ keys = list(checkpoint.keys())
389
+ for key in keys:
390
+ if key.startswith(vae_key):
391
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
392
+ # if len(vae_state_dict) == 0:
393
+ # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
394
+ # vae_state_dict = checkpoint
395
+
396
+ new_checkpoint = {}
397
+
398
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
399
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
400
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
401
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
402
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
403
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
404
+
405
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
406
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
407
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
408
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
409
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
410
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
411
+
412
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
413
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
414
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
415
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
416
+
417
+ # Retrieves the keys for the encoder down blocks only
418
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
419
+ down_blocks = {
420
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
421
+ }
422
+
423
+ # Retrieves the keys for the decoder up blocks only
424
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
425
+ up_blocks = {
426
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
427
+ }
428
+
429
+ for i in range(num_down_blocks):
430
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
431
+
432
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
433
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
434
+ f"encoder.down.{i}.downsample.conv.weight"
435
+ )
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.bias"
438
+ )
439
+
440
+ paths = renew_vae_resnet_paths(resnets)
441
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
442
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
443
+
444
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
445
+ num_mid_res_blocks = 2
446
+ for i in range(1, num_mid_res_blocks + 1):
447
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
448
+
449
+ paths = renew_vae_resnet_paths(resnets)
450
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
451
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
452
+
453
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
454
+ paths = renew_vae_attention_paths(mid_attentions)
455
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
456
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
457
+ conv_attn_to_linear(new_checkpoint)
458
+
459
+ for i in range(num_up_blocks):
460
+ block_id = num_up_blocks - 1 - i
461
+ resnets = [
462
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
463
+ ]
464
+
465
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
466
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
467
+ f"decoder.up.{block_id}.upsample.conv.weight"
468
+ ]
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.bias"
471
+ ]
472
+
473
+ paths = renew_vae_resnet_paths(resnets)
474
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
475
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
476
+
477
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
478
+ num_mid_res_blocks = 2
479
+ for i in range(1, num_mid_res_blocks + 1):
480
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
481
+
482
+ paths = renew_vae_resnet_paths(resnets)
483
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
484
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
485
+
486
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
487
+ paths = renew_vae_attention_paths(mid_attentions)
488
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
489
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
490
+ conv_attn_to_linear(new_checkpoint)
491
+ return new_checkpoint
492
+
493
+
494
+ def create_unet_diffusers_config(v2):
495
+ """
496
+ Creates a config for the diffusers based on the config of the LDM model.
497
+ """
498
+ # unet_params = original_config.model.params.unet_config.params
499
+
500
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
501
+
502
+ down_block_types = []
503
+ resolution = 1
504
+ for i in range(len(block_out_channels)):
505
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
506
+ down_block_types.append(block_type)
507
+ if i != len(block_out_channels) - 1:
508
+ resolution *= 2
509
+
510
+ up_block_types = []
511
+ for i in range(len(block_out_channels)):
512
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
513
+ up_block_types.append(block_type)
514
+ resolution //= 2
515
+
516
+ config = dict(
517
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
518
+ in_channels=UNET_PARAMS_IN_CHANNELS,
519
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
520
+ down_block_types=tuple(down_block_types),
521
+ up_block_types=tuple(up_block_types),
522
+ block_out_channels=tuple(block_out_channels),
523
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
524
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
525
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
526
+ )
527
+
528
+ return config
529
+
530
+
531
+ def create_vae_diffusers_config():
532
+ """
533
+ Creates a Diffusers config based on the config of the LDM model.
534
+ """
535
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
536
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
537
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
538
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
539
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
540
+
541
+ config = dict(
542
+ sample_size=VAE_PARAMS_RESOLUTION,
543
+ in_channels=VAE_PARAMS_IN_CHANNELS,
544
+ out_channels=VAE_PARAMS_OUT_CH,
545
+ down_block_types=tuple(down_block_types),
546
+ up_block_types=tuple(up_block_types),
547
+ block_out_channels=tuple(block_out_channels),
548
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
549
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
550
+ )
551
+ return config
552
+
553
+
554
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
555
+ keys = list(checkpoint.keys())
556
+ text_model_dict = {}
557
+ for key in keys:
558
+ if key.startswith("cond_stage_model.transformer"):
559
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
560
+ return text_model_dict
561
+
562
+
563
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
564
+ # the key layout differs so much it's maddening!
565
+ def convert_key(key):
566
+ if not key.startswith("cond_stage_model"):
567
+ return None
568
+
569
+ # common conversion
570
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
571
+ key = key.replace("cond_stage_model.model.", "text_model.")
572
+
573
+ if "resblocks" in key:
574
+ # resblocks conversion
575
+ key = key.replace(".resblocks.", ".layers.")
576
+ if ".ln_" in key:
577
+ key = key.replace(".ln_", ".layer_norm")
578
+ elif ".mlp." in key:
579
+ key = key.replace(".c_fc.", ".fc1.")
580
+ key = key.replace(".c_proj.", ".fc2.")
581
+ elif '.attn.out_proj' in key:
582
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
583
+ elif '.attn.in_proj' in key:
584
+ key = None # special case, handled separately below
585
+ else:
586
+ raise ValueError(f"unexpected key in SD: {key}")
587
+ elif '.positional_embedding' in key:
588
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
589
+ elif '.text_projection' in key:
590
+ key = None # apparently unused???
591
+ elif '.logit_scale' in key:
592
+ key = None # apparently unused???
593
+ elif '.token_embedding' in key:
594
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
595
+ elif '.ln_final' in key:
596
+ key = key.replace(".ln_final", ".final_layer_norm")
597
+ return key
598
+
599
+ keys = list(checkpoint.keys())
600
+ new_sd = {}
601
+ for key in keys:
602
+ # remove resblocks 23
603
+ if '.resblocks.23.' in key:
604
+ continue
605
+ new_key = convert_key(key)
606
+ if new_key is None:
607
+ continue
608
+ new_sd[new_key] = checkpoint[key]
609
+
610
+ # convert attention weights
611
+ for key in keys:
612
+ if '.resblocks.23.' in key:
613
+ continue
614
+ if '.resblocks' in key and '.attn.in_proj_' in key:
615
+ # split into three (q, k, v)
616
+ values = torch.chunk(checkpoint[key], 3)
617
+
618
+ key_suffix = ".weight" if "weight" in key else ".bias"
619
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
620
+ key_pfx = key_pfx.replace("_weight", "")
621
+ key_pfx = key_pfx.replace("_bias", "")
622
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
623
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
624
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
625
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
626
+
627
+ # rename or add position_ids
628
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
629
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
630
+ # waifu diffusion v1.4
631
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
632
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
633
+ else:
634
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
635
+
636
+ new_sd["text_model.embeddings.position_ids"] = position_ids
637
+ return new_sd
638
+
639
+ # endregion
640
+
641
+
642
+ # region conversion code: Diffusers -> StableDiffusion
643
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
644
+
645
+ def conv_transformer_to_linear(checkpoint):
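+ # SD-format v2 checkpoints store these projections as Linear weights; squeeze conv-shaped (out, in, 1, 1) tensors down to (out, in)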
646
+ keys = list(checkpoint.keys())
647
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
648
+ for key in keys:
649
+ if ".".join(key.split(".")[-2:]) in tf_keys:
650
+ if checkpoint[key].ndim > 2:
651
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
652
+
653
+
654
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
655
+ unet_conversion_map = [
656
+ # (stable-diffusion, HF Diffusers)
657
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
658
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
659
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
660
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
661
+ ("input_blocks.0.0.weight", "conv_in.weight"),
662
+ ("input_blocks.0.0.bias", "conv_in.bias"),
663
+ ("out.0.weight", "conv_norm_out.weight"),
664
+ ("out.0.bias", "conv_norm_out.bias"),
665
+ ("out.2.weight", "conv_out.weight"),
666
+ ("out.2.bias", "conv_out.bias"),
667
+ ]
668
+
669
+ unet_conversion_map_resnet = [
670
+ # (stable-diffusion, HF Diffusers)
671
+ ("in_layers.0", "norm1"),
672
+ ("in_layers.2", "conv1"),
673
+ ("out_layers.0", "norm2"),
674
+ ("out_layers.3", "conv2"),
675
+ ("emb_layers.1", "time_emb_proj"),
676
+ ("skip_connection", "conv_shortcut"),
677
+ ]
678
+
679
+ unet_conversion_map_layer = []
680
+ for i in range(4):
681
+ # loop over downblocks/upblocks
682
+
683
+ for j in range(2):
684
+ # loop over resnets/attentions for downblocks
685
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
686
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
687
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
688
+
689
+ if i < 3:
690
+ # no attention layers in down_blocks.3
691
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
692
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
693
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
694
+
695
+ for j in range(3):
696
+ # loop over resnets/attentions for upblocks
697
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
698
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
699
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
700
+
701
+ if i > 0:
702
+ # no attention layers in up_blocks.0
703
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
704
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
705
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
706
+
707
+ if i < 3:
708
+ # no downsample in down_blocks.3
709
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
710
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
711
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
712
+
713
+ # no upsample in up_blocks.3
714
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
715
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
716
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
717
+
718
+ hf_mid_atn_prefix = "mid_block.attentions.0."
719
+ sd_mid_atn_prefix = "middle_block.1."
720
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
721
+
722
+ for j in range(2):
723
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
724
+ sd_mid_res_prefix = f"middle_block.{2*j}."
725
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
726
+
727
+ # buyer beware: this is a *brittle* function,
728
+ # and correct output requires that all of these pieces interact in
729
+ # the exact order in which I have arranged them.
730
+ mapping = {k: k for k in unet_state_dict.keys()}
731
+ for sd_name, hf_name in unet_conversion_map:
732
+ mapping[hf_name] = sd_name
733
+ for k, v in mapping.items():
734
+ if "resnets" in k:
735
+ for sd_part, hf_part in unet_conversion_map_resnet:
736
+ v = v.replace(hf_part, sd_part)
737
+ mapping[k] = v
738
+ for k, v in mapping.items():
739
+ for sd_part, hf_part in unet_conversion_map_layer:
740
+ v = v.replace(hf_part, sd_part)
741
+ mapping[k] = v
742
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
743
+
744
+ if v2:
745
+ conv_transformer_to_linear(new_state_dict)
746
+
747
+ return new_state_dict
748
+
749
+
750
+ # ================#
751
+ # VAE Conversion #
752
+ # ================#
753
+
754
+ def reshape_weight_for_sd(w):
755
+ # convert HF linear weights to SD conv2d weights
756
+ return w.reshape(*w.shape, 1, 1)
757
+
758
+
759
+ def convert_vae_state_dict(vae_state_dict):
760
+ vae_conversion_map = [
761
+ # (stable-diffusion, HF Diffusers)
762
+ ("nin_shortcut", "conv_shortcut"),
763
+ ("norm_out", "conv_norm_out"),
764
+ ("mid.attn_1.", "mid_block.attentions.0."),
765
+ ]
766
+
767
+ for i in range(4):
768
+ # down_blocks have two resnets
769
+ for j in range(2):
770
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
771
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
772
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
773
+
774
+ if i < 3:
775
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
776
+ sd_downsample_prefix = f"down.{i}.downsample."
777
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
778
+
779
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
780
+ sd_upsample_prefix = f"up.{3-i}.upsample."
781
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
782
+
783
+ # up_blocks have three resnets
784
+ # also, up blocks in hf are numbered in reverse from sd
785
+ for j in range(3):
786
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
787
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
788
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
789
+
790
+ # this part accounts for mid blocks in both the encoder and the decoder
791
+ for i in range(2):
792
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
793
+ sd_mid_res_prefix = f"mid.block_{i+1}."
794
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
795
+
796
+ vae_conversion_map_attn = [
797
+ # (stable-diffusion, HF Diffusers)
798
+ ("norm.", "group_norm."),
799
+ ("q.", "query."),
800
+ ("k.", "key."),
801
+ ("v.", "value."),
802
+ ("proj_out.", "proj_attn."),
803
+ ]
804
+
805
+ mapping = {k: k for k in vae_state_dict.keys()}
806
+ for k, v in mapping.items():
807
+ for sd_part, hf_part in vae_conversion_map:
808
+ v = v.replace(hf_part, sd_part)
809
+ mapping[k] = v
810
+ for k, v in mapping.items():
811
+ if "attentions" in k:
812
+ for sd_part, hf_part in vae_conversion_map_attn:
813
+ v = v.replace(hf_part, sd_part)
814
+ mapping[k] = v
815
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
816
+ weights_to_convert = ["q", "k", "v", "proj_out"]
817
+ for k, v in new_state_dict.items():
818
+ for weight_name in weights_to_convert:
819
+ if f"mid.attn_1.{weight_name}.weight" in k:
820
+ # print(f"Reshaping {k} for SD format")
821
+ new_state_dict[k] = reshape_weight_for_sd(v)
822
+
823
+ return new_state_dict
824
+
825
+
826
+ # endregion
827
+
828
+ # region custom model loading/saving helpers
829
+
830
+ def is_safetensors(path):
831
+ return os.path.splitext(path)[1].lower() == '.safetensors'
832
+
833
+
834
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
835
+ # support models whose text encoder is stored differently (without 'text_model')
836
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
837
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
838
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
839
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
840
+ ]
841
+
842
+ if is_safetensors(ckpt_path):
843
+ checkpoint = None
844
+ state_dict = load_file(ckpt_path, "cpu")
845
+ else:
846
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
847
+ if "state_dict" in checkpoint:
848
+ state_dict = checkpoint["state_dict"]
849
+ else:
850
+ state_dict = checkpoint
851
+ checkpoint = None
852
+
853
+ key_reps = []
854
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
855
+ for key in state_dict.keys():
856
+ if key.startswith(rep_from):
857
+ new_key = rep_to + key[len(rep_from):]
858
+ key_reps.append((key, new_key))
859
+
860
+ for key, new_key in key_reps:
861
+ state_dict[new_key] = state_dict[key]
862
+ del state_dict[key]
863
+
864
+ return checkpoint, state_dict
865
+
866
+
867
+ # TODO the dtype handling looks unreliable and needs checking; creating the text_encoder in the requested dtype is unverified
868
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
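+ # load a single LDM/SD checkpoint and return Diffusers-format (text_encoder, vae, unet) modules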
869
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
870
+ if dtype is not None:
871
+ for k, v in state_dict.items():
872
+ if type(v) is torch.Tensor:
873
+ state_dict[k] = v.to(dtype)
874
+
875
+ # Convert the UNet2DConditionModel model.
876
+ unet_config = create_unet_diffusers_config(v2)
877
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
878
+
879
+ unet = UNet2DConditionModel(**unet_config)
880
+ info = unet.load_state_dict(converted_unet_checkpoint)
881
+ print("loading u-net:", info)
882
+
883
+ # Convert the VAE model.
884
+ vae_config = create_vae_diffusers_config()
885
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
886
+
887
+ vae = AutoencoderKL(**vae_config)
888
+ info = vae.load_state_dict(converted_vae_checkpoint)
889
+ print("loading vae:", info)
890
+
891
+ # convert text_model
892
+ if v2:
893
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
894
+ cfg = CLIPTextConfig(
895
+ vocab_size=49408,
896
+ hidden_size=1024,
897
+ intermediate_size=4096,
898
+ num_hidden_layers=23,
899
+ num_attention_heads=16,
900
+ max_position_embeddings=77,
901
+ hidden_act="gelu",
902
+ layer_norm_eps=1e-05,
903
+ dropout=0.0,
904
+ attention_dropout=0.0,
905
+ initializer_range=0.02,
906
+ initializer_factor=1.0,
907
+ pad_token_id=1,
908
+ bos_token_id=0,
909
+ eos_token_id=2,
910
+ model_type="clip_text_model",
911
+ projection_dim=512,
912
+ torch_dtype="float32",
913
+ transformers_version="4.25.0.dev0",
914
+ )
915
+ text_model = CLIPTextModel._from_config(cfg)
916
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
+ else:
918
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
920
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
+ print("loading text encoder:", info)
922
+
923
+ return text_model, vae, unet
924
+
925
+
926
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
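+ # map Diffusers/transformers CLIP keys back to the open_clip-style layout used inside SD v2 checkpoints; q/k/v are merged into in_proj below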
927
+ def convert_key(key):
928
+ # drop position_ids
929
+ if ".position_ids" in key:
930
+ return None
931
+
932
+ # common
933
+ key = key.replace("text_model.encoder.", "transformer.")
934
+ key = key.replace("text_model.", "")
935
+ if "layers" in key:
936
+ # resblocks conversion
937
+ key = key.replace(".layers.", ".resblocks.")
938
+ if ".layer_norm" in key:
939
+ key = key.replace(".layer_norm", ".ln_")
940
+ elif ".mlp." in key:
941
+ key = key.replace(".fc1.", ".c_fc.")
942
+ key = key.replace(".fc2.", ".c_proj.")
943
+ elif '.self_attn.out_proj' in key:
944
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
945
+ elif '.self_attn.' in key:
946
+ key = None # special case, handled separately below
947
+ else:
948
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
949
+ elif '.position_embedding' in key:
950
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
951
+ elif '.token_embedding' in key:
952
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
953
+ elif 'final_layer_norm' in key:
954
+ key = key.replace("final_layer_norm", "ln_final")
955
+ return key
956
+
957
+ keys = list(checkpoint.keys())
958
+ new_sd = {}
959
+ for key in keys:
960
+ new_key = convert_key(key)
961
+ if new_key is None:
962
+ continue
963
+ new_sd[new_key] = checkpoint[key]
964
+
965
+ # convert attention weights
966
+ for key in keys:
967
+ if 'layers' in key and 'q_proj' in key:
968
+ # concatenate q, k, v into one tensor
969
+ key_q = key
970
+ key_k = key.replace("q_proj", "k_proj")
971
+ key_v = key.replace("q_proj", "v_proj")
972
+
973
+ value_q = checkpoint[key_q]
974
+ value_k = checkpoint[key_k]
975
+ value_v = checkpoint[key_v]
976
+ value = torch.cat([value_q, value_k, value_v])
977
+
978
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
979
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
980
+ new_sd[new_key] = value
981
+
982
+ # whether to fabricate the final layer and other missing weights
983
+ if make_dummy_weights:
984
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
985
+ keys = list(new_sd.keys())
986
+ for key in keys:
987
+ if key.startswith("transformer.resblocks.22."):
988
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # saving with safetensors fails unless we copy
989
+
990
+ # create weights that the Diffusers model does not contain
991
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
992
+ new_sd['logit_scale'] = torch.tensor(1)
993
+
994
+ return new_sd
995
+
996
+
997
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
998
+ if ckpt_path is not None:
999
+ # reuse epoch/step from the checkpoint; also reload it (including the VAE) e.g. when the VAE is not in memory
1000
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1001
+ if checkpoint is None: # safetensors file, or a ckpt that holds only a state_dict
1002
+ checkpoint = {}
1003
+ strict = False
1004
+ else:
1005
+ strict = True
1006
+ if "state_dict" in state_dict:
1007
+ del state_dict["state_dict"]
1008
+ else:
1009
+ # build a new checkpoint from scratch
1010
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1011
+ checkpoint = {}
1012
+ state_dict = {}
1013
+ strict = False
1014
+
1015
+ def update_sd(prefix, sd):
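+ # copy sd into state_dict under the given key prefix, optionally casting to save_dtype on CPU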
1016
+ for k, v in sd.items():
1017
+ key = prefix + k
1018
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1019
+ if save_dtype is not None:
1020
+ v = v.detach().clone().to("cpu").to(save_dtype)
1021
+ state_dict[key] = v
1022
+
1023
+ # Convert the UNet model
1024
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1025
+ update_sd("model.diffusion_model.", unet_state_dict)
1026
+
1027
+ # Convert the text encoder model
1028
+ if v2:
1029
+ make_dummy = ckpt_path is None # without a reference checkpoint, fill in dummy weights, e.g. duplicate the last layer from the previous one
1030
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1031
+ update_sd("cond_stage_model.model.", text_enc_dict)
1032
+ else:
1033
+ text_enc_dict = text_encoder.state_dict()
1034
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1035
+
1036
+ # Convert the VAE
1037
+ if vae is not None:
1038
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1039
+ update_sd("first_stage_model.", vae_dict)
1040
+
1041
+ # Put together new checkpoint
1042
+ key_count = len(state_dict.keys())
1043
+ new_ckpt = {'state_dict': state_dict}
1044
+
1045
+ if 'epoch' in checkpoint:
1046
+ epochs += checkpoint['epoch']
1047
+ if 'global_step' in checkpoint:
1048
+ steps += checkpoint['global_step']
1049
+
1050
+ new_ckpt['epoch'] = epochs
1051
+ new_ckpt['global_step'] = steps
1052
+
1053
+ if is_safetensors(output_file):
1054
+ # TODO should non-Tensor values be dropped from the dict?
1055
+ save_file(state_dict, output_file)
1056
+ else:
1057
+ torch.save(new_ckpt, output_file)
1058
+
1059
+ return key_count
1060
+
1061
+
1062
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1063
+ if pretrained_model_name_or_path is None:
1064
+ # load default settings for v1/v2
1065
+ if v2:
1066
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1067
+ else:
1068
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1069
+
1070
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1071
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1072
+ if vae is None:
1073
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1074
+
1075
+ pipeline = StableDiffusionPipeline(
1076
+ unet=unet,
1077
+ text_encoder=text_encoder,
1078
+ vae=vae,
1079
+ scheduler=scheduler,
1080
+ tokenizer=tokenizer,
1081
+ safety_checker=None,
1082
+ feature_extractor=None,
1083
+ requires_safety_checker=None,
1084
+ )
1085
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1086
+
1087
+
1088
+ VAE_PREFIX = "first_stage_model."
1089
+
1090
+
1091
+ def load_vae(vae_id, dtype):
1092
+ print(f"load VAE: {vae_id}")
1093
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1094
+ # Diffusers local/remote
1095
+ try:
1096
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1097
+ except EnvironmentError as e:
1098
+ print(f"exception occurs in loading vae: {e}")
1099
+ print("retry with subfolder='vae'")
1100
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1101
+ return vae
1102
+
1103
+ # local
1104
+ vae_config = create_vae_diffusers_config()
1105
+
1106
+ if vae_id.endswith(".bin"):
1107
+ # SD 1.5 VAE on Huggingface
1108
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1109
+ else:
1110
+ # StableDiffusion
1111
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1112
+ else torch.load(vae_id, map_location="cpu"))
1113
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1114
+
1115
+ # vae only or full model
1116
+ full_model = False
1117
+ for vae_key in vae_sd:
1118
+ if vae_key.startswith(VAE_PREFIX):
1119
+ full_model = True
1120
+ break
1121
+ if not full_model:
1122
+ sd = {}
1123
+ for key, value in vae_sd.items():
1124
+ sd[VAE_PREFIX + key] = value
1125
+ vae_sd = sd
1126
+ del sd
1127
+
1128
+ # Convert the VAE model.
1129
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1130
+
1131
+ vae = AutoencoderKL(**vae_config)
1132
+ vae.load_state_dict(converted_vae_checkpoint)
1133
+ return vae
1134
+
1135
+ # endregion
1136
+
1137
+
1138
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
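+ # enumerate bucket resolutions: a near-square bucket at the maximum area, plus (w, h)/(h, w) pairs in steps of divisible whose area stays within that of max_reso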
1139
+ max_width, max_height = max_reso
1140
+ max_area = (max_width // divisible) * (max_height // divisible)
1141
+
1142
+ resos = set()
1143
+
1144
+ size = int(math.sqrt(max_area)) * divisible
1145
+ resos.add((size, size))
1146
+
1147
+ size = min_size
1148
+ while size <= max_size:
1149
+ width = size
1150
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1151
+ resos.add((width, height))
1152
+ resos.add((height, width))
1153
+
1154
+ # # make additional resos
1155
+ # if width >= height and width - divisible >= min_size:
1156
+ # resos.add((width - divisible, height))
1157
+ # resos.add((height, width - divisible))
1158
+ # if height >= width and height - divisible >= min_size:
1159
+ # resos.add((width, height - divisible))
1160
+ # resos.add((height - divisible, width))
1161
+
1162
+ size += divisible
1163
+
1164
+ resos = list(resos)
1165
+ resos.sort()
1166
+ return resos
1167
+
1168
+
1169
+ if __name__ == '__main__':
1170
+ resos = make_bucket_resolutions((512, 768))
1171
+ print(len(resos))
1172
+ print(resos)
1173
+ aspect_ratios = [w / h for w, h in resos]
1174
+ print(aspect_ratios)
1175
+
1176
+ ars = set()
1177
+ for ar in aspect_ratios:
1178
+ if ar in ars:
1179
+ print("error! duplicate ar:", ar)
1180
+ ars.add(ar)
build/lib/library/train_util.py ADDED
@@ -0,0 +1,1796 @@
1
+ # common functions for training
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ import time
7
+ from typing import Dict, List, NamedTuple, Tuple
8
+ from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
+ import glob
11
+ import math
12
+ import os
13
+ import random
14
+ import hashlib
15
+ import subprocess
16
+ from io import BytesIO
17
+
18
+ from tqdm import tqdm
19
+ import torch
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer
22
+ import diffusers
23
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
24
+ import albumentations as albu
25
+ import numpy as np
26
+ from PIL import Image
27
+ import cv2
28
+ from einops import rearrange
29
+ from torch import einsum
30
+ import safetensors.torch
31
+
32
+ import library.model_util as model_util
33
+
34
+ # Tokenizer: use the pre-packaged tokenizer instead of loading it from the checkpoint
35
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
36
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
37
+
38
+ # checkpoint file name templates
39
+ EPOCH_STATE_NAME = "{}-{:06d}-state"
40
+ EPOCH_FILE_NAME = "{}-{:06d}"
41
+ EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
42
+ LAST_STATE_NAME = "{}-state"
43
+ DEFAULT_EPOCH_NAME = "epoch"
44
+ DEFAULT_LAST_OUTPUT_NAME = "last"
45
+
46
+ # region dataset
47
+
48
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
49
+ # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
50
+
51
+
52
+ class ImageInfo():
53
+ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
54
+ self.image_key: str = image_key
55
+ self.num_repeats: int = num_repeats
56
+ self.caption: str = caption
57
+ self.is_reg: bool = is_reg
58
+ self.absolute_path: str = absolute_path
59
+ self.image_size: Tuple[int, int] = None
60
+ self.resized_size: Tuple[int, int] = None
61
+ self.bucket_reso: Tuple[int, int] = None
62
+ self.latents: torch.Tensor = None
63
+ self.latents_flipped: torch.Tensor = None
64
+ self.latents_npz: str = None
65
+ self.latents_npz_flipped: str = None
66
+
67
+
68
+ class BucketManager():
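+ # assigns images to aspect-ratio buckets; bucket sides are multiples of reso_steps (see select_bucket)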
69
+ def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
70
+ self.no_upscale = no_upscale
71
+ if max_reso is None:
72
+ self.max_reso = None
73
+ self.max_area = None
74
+ else:
75
+ self.max_reso = max_reso
76
+ self.max_area = max_reso[0] * max_reso[1]
77
+ self.min_size = min_size
78
+ self.max_size = max_size
79
+ self.reso_steps = reso_steps
80
+
81
+ self.resos = []
82
+ self.reso_to_id = {}
83
+ self.buckets = [] # holds (image_key, image) during preprocessing, image_key during training
84
+
85
+ def add_image(self, reso, image):
86
+ bucket_id = self.reso_to_id[reso]
87
+ self.buckets[bucket_id].append(image)
88
+
89
+ def shuffle(self):
90
+ for bucket in self.buckets:
91
+ random.shuffle(bucket)
92
+
93
+ def sort(self):
94
+ # sort by resolution (purely for nicer display and metadata); reorder buckets and rebuild reso_to_id to match
95
+ sorted_resos = self.resos.copy()
96
+ sorted_resos.sort()
97
+
98
+ sorted_buckets = []
99
+ sorted_reso_to_id = {}
100
+ for i, reso in enumerate(sorted_resos):
101
+ bucket_id = self.reso_to_id[reso]
102
+ sorted_buckets.append(self.buckets[bucket_id])
103
+ sorted_reso_to_id[reso] = i
104
+
105
+ self.resos = sorted_resos
106
+ self.buckets = sorted_buckets
107
+ self.reso_to_id = sorted_reso_to_id
108
+
109
+ def make_buckets(self):
110
+ resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
111
+ self.set_predefined_resos(resos)
112
+
113
+ def set_predefined_resos(self, resos):
114
+ # store resolution and aspect ratio info for selecting among the predefined sizes
115
+ self.predefined_resos = resos.copy()
116
+ self.predefined_resos_set = set(resos)
117
+ self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
118
+
119
+ def add_if_new_reso(self, reso):
120
+ if reso not in self.reso_to_id:
121
+ bucket_id = len(self.resos)
122
+ self.reso_to_id[reso] = bucket_id
123
+ self.resos.append(reso)
124
+ self.buckets.append([])
125
+ # print(reso, bucket_id, len(self.buckets))
126
+
127
+ def round_to_steps(self, x):
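+ # round x to the nearest integer, then down to a multiple of reso_steps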
128
+ x = int(x + .5)
129
+ return x - x % self.reso_steps
130
+
131
+ def select_bucket(self, image_width, image_height):
132
+ aspect_ratio = image_width / image_height
133
+ if not self.no_upscale:
134
+ # an identical resolution may exist (when preprocessed with no_upscale=True for fine tuning), so prefer an exact match
135
+ reso = (image_width, image_height)
136
+ if reso in self.predefined_resos_set:
137
+ pass
138
+ else:
139
+ ar_errors = self.predefined_aspect_ratios - aspect_ratio
140
+ predefined_bucket_id = np.abs(ar_errors).argmin() # the predefined resolution with the smallest aspect ratio error
141
+ reso = self.predefined_resos[predefined_bucket_id]
142
+
143
+ ar_reso = reso[0] / reso[1]
144
+ if aspect_ratio > ar_reso: # image is wider than the bucket -> fit the height
145
+ scale = reso[1] / image_height
146
+ else:
147
+ scale = reso[0] / image_width
148
+
149
+ resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
150
+ # print("use predef", image_width, image_height, reso, resized_size)
151
+ else:
152
+ if image_width * image_height > self.max_area:
153
+ # the image is too large, so pick the bucket assuming it will be shrunk while preserving aspect ratio
154
+ resized_width = math.sqrt(self.max_area * aspect_ratio)
155
+ resized_height = self.max_area / resized_width
156
+ assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
157
+
158
+ # round the resized short or long side to a multiple of reso_steps, choosing whichever yields the smaller aspect ratio error
159
+ # same logic as the original bucketing
160
+ b_width_rounded = self.round_to_steps(resized_width)
161
+ b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
162
+ ar_width_rounded = b_width_rounded / b_height_in_wr
163
+
164
+ b_height_rounded = self.round_to_steps(resized_height)
165
+ b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
166
+ ar_height_rounded = b_width_in_hr / b_height_rounded
167
+
168
+ # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
169
+ # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
170
+
171
+ if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
172
+ resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
173
+ else:
174
+ resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
175
+ # print(resized_size)
176
+ else:
177
+ resized_size = (image_width, image_height) # no resizing needed
178
+
179
+ # make the bucket no larger than the image (crop rather than pad)
180
+ bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
181
+ bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
182
+ # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
183
+
184
+ reso = (bucket_width, bucket_height)
185
+
186
+ self.add_if_new_reso(reso)
187
+
188
+ ar_error = (reso[0] / reso[1]) - aspect_ratio
189
+ return reso, resized_size, ar_error
190
+
191
+
192
+ class BucketBatchIndex(NamedTuple):
193
+ bucket_index: int
194
+ bucket_batch_size: int
195
+ batch_index: int
196
+
197
+
198
+ class BaseDataset(torch.utils.data.Dataset):
199
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
+ super().__init__()
201
+ self.tokenizer: CLIPTokenizer = tokenizer
202
+ self.max_token_length = max_token_length
203
+ self.shuffle_caption = shuffle_caption
204
+ self.shuffle_keep_tokens = shuffle_keep_tokens
205
+ # width/height is used when enable_bucket==False
206
+ self.width, self.height = (None, None) if resolution is None else resolution
207
+ self.face_crop_aug_range = face_crop_aug_range
208
+ self.flip_aug = flip_aug
209
+ self.color_aug = color_aug
210
+ self.debug_dataset = debug_dataset
211
+ self.random_crop = random_crop
212
+ self.token_padding_disabled = False
213
+ self.dataset_dirs_info = {}
214
+ self.reg_dataset_dirs_info = {}
215
+ self.tag_frequency = {}
216
+
217
+ self.enable_bucket = False
218
+ self.bucket_manager: BucketManager = None # not initialized
219
+ self.min_bucket_reso = None
220
+ self.max_bucket_reso = None
221
+ self.bucket_reso_steps = None
222
+ self.bucket_no_upscale = None
223
+ self.bucket_info = None # for metadata
224
+
225
+ self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
+
227
+ self.current_epoch: int = 0 # instances seem to be recreated every epoch, so this has to be set from outside
228
+ self.dropout_rate: float = 0
229
+ self.dropout_every_n_epochs: int = None
230
+ self.tag_dropout_rate: float = 0
231
+
232
+ # augmentation
233
+ flip_p = 0.5 if flip_aug else 0.0
234
+ if color_aug:
235
+ # fairly mild color augmentation: brightness/contrast would change the image's min/max pixel values, which is assumed to be undesirable, so only gamma/hue are adjusted
236
+ self.aug = albu.Compose([
237
+ albu.OneOf([
238
+ albu.HueSaturationValue(8, 0, 0, p=.5),
239
+ albu.RandomGamma((95, 105), p=.5),
240
+ ], p=.33),
241
+ albu.HorizontalFlip(p=flip_p)
242
+ ], p=1.)
243
+ elif flip_aug:
244
+ self.aug = albu.Compose([
245
+ albu.HorizontalFlip(p=flip_p)
246
+ ], p=1.)
247
+ else:
248
+ self.aug = None
249
+
250
+ self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
+
252
+ self.image_data: Dict[str, ImageInfo] = {}
253
+
254
+ self.replacements = {}
255
+
256
+ def set_current_epoch(self, epoch):
257
+ self.current_epoch = epoch
258
+
259
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
+ # not passed via the constructor so that Textual Inversion does not have to care about it (at least that's the story)
261
+ self.dropout_rate = dropout_rate
262
+ self.dropout_every_n_epochs = dropout_every_n_epochs
263
+ self.tag_dropout_rate = tag_dropout_rate
264
+
265
+ def set_tag_frequency(self, dir_name, captions):
266
+ frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
+ self.tag_frequency[dir_name] = frequency_for_dir
268
+ for caption in captions:
269
+ for tag in caption.split(","):
270
+ if tag and not tag.isspace():
271
+ tag = tag.lower()
272
+ frequency = frequency_for_dir.get(tag, 0)
273
+ frequency_for_dir[tag] = frequency + 1
274
+
275
+ def disable_token_padding(self):
276
+ self.token_padding_disabled = True
277
+
278
+ def add_replacement(self, str_from, str_to):
279
+ self.replacements[str_from] = str_to
280
+
281
+ def process_caption(self, caption):
282
+ # decide dropout here: tag dropout also happens inside this method, so this is the natural place
283
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
+
286
+ if is_drop_out:
287
+ caption = ""
288
+ else:
289
+ if self.shuffle_caption or self.tag_dropout_rate > 0:
290
+ def dropout_tags(tokens):
291
+ if self.tag_dropout_rate <= 0:
292
+ return tokens
293
+ l = []
294
+ for token in tokens:
295
+ if random.random() >= self.tag_dropout_rate:
296
+ l.append(token)
297
+ return l
298
+
299
+ tokens = [t.strip() for t in caption.strip().split(",")]
300
+ if self.shuffle_keep_tokens is None:
301
+ if self.shuffle_caption:
302
+ random.shuffle(tokens)
303
+
304
+ tokens = dropout_tags(tokens)
305
+ else:
306
+ if len(tokens) > self.shuffle_keep_tokens:
307
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
308
+ tokens = tokens[self.shuffle_keep_tokens:]
309
+
310
+ if self.shuffle_caption:
311
+ random.shuffle(tokens)
312
+
313
+ tokens = dropout_tags(tokens)
314
+
315
+ tokens = keep_tokens + tokens
316
+ caption = ", ".join(tokens)
317
+
318
+ # textual inversion support
319
+ for str_from, str_to in self.replacements.items():
320
+ if str_from == "":
321
+ # replace all
322
+ if type(str_to) == list:
323
+ caption = random.choice(str_to)
324
+ else:
325
+ caption = str_to
326
+ else:
327
+ caption = caption.replace(str_from, str_to)
328
+
329
+ return caption
330
+
331
+ def get_input_ids(self, caption):
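+ # tokenize the caption; if max_token_length exceeds the model limit, split the ids into several 77-token <BOS>...<EOS> chunks and stack them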
332
+ input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
333
+ max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
334
+
335
+ if self.tokenizer_max_length > self.tokenizer.model_max_length:
336
+ input_ids = input_ids.squeeze(0)
337
+ iids_list = []
338
+ if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
339
+ # v1
340
+ # when longer than 77, the ids look like "<BOS> .... <EOS> <EOS> <EOS>" with a total of e.g. 227, so convert them into three "<BOS>...<EOS>" chunks
341
+ # the AUTOMATIC1111 implementation seems to split on "," etc., but keep it simple for now
342
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
343
+ ids_chunk = (input_ids[0].unsqueeze(0),
344
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
345
+ input_ids[-1].unsqueeze(0))
346
+ ids_chunk = torch.cat(ids_chunk)
347
+ iids_list.append(ids_chunk)
348
+ else:
349
+ # v2
350
+ # when longer than 77, the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." with a total of e.g. 227, so convert them into three "<BOS>...<EOS> <PAD> <PAD> ..." chunks
351
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
352
+ ids_chunk = (input_ids[0].unsqueeze(0), # BOS
353
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
354
+ input_ids[-1].unsqueeze(0)) # PAD or EOS
355
+ ids_chunk = torch.cat(ids_chunk)
356
+
357
+ # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
358
+ # if it ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it was already x <EOS>)
359
+ if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
360
+ ids_chunk[-1] = self.tokenizer.eos_token_id
361
+ # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
362
+ if ids_chunk[1] == self.tokenizer.pad_token_id:
363
+ ids_chunk[1] = self.tokenizer.eos_token_id
364
+
365
+ iids_list.append(ids_chunk)
366
+
367
+ input_ids = torch.stack(iids_list) # 3,77
368
+ return input_ids
369
+
370
+ def register_image(self, info: ImageInfo):
371
+ self.image_data[info.image_key] = info
372
+
373
+ def make_buckets(self):
374
+ '''
375
+ must be called even when bucketing is disabled (a single bucket is created)
376
+ min_size and max_size are ignored when enable_bucket is False
377
+ '''
378
+ print("loading image sizes.")
379
+ for info in tqdm(self.image_data.values()):
380
+ if info.image_size is None:
381
+ info.image_size = self.get_image_size(info.absolute_path)
382
+
383
+ if self.enable_bucket:
384
+ print("make buckets")
385
+ else:
386
+ print("prepare dataset")
387
+
388
+ # create the buckets and assign each image to one
389
+ if self.enable_bucket:
390
+ if self.bucket_manager is None: # already initialized for fine tuning when the metadata defines the buckets
391
+ self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
392
+ self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
393
+ if not self.bucket_no_upscale:
394
+ self.bucket_manager.make_buckets()
395
+ else:
396
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
397
+
398
+ img_ar_errors = []
399
+ for image_info in self.image_data.values():
400
+ image_width, image_height = image_info.image_size
401
+ image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
402
+
403
+ # print(image_info.image_key, image_info.bucket_reso)
404
+ img_ar_errors.append(abs(ar_error))
405
+
406
+ self.bucket_manager.sort()
407
+ else:
408
+ self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
409
+ self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
410
+ for image_info in self.image_data.values():
411
+ image_width, image_height = image_info.image_size
412
+ image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
413
+
414
+ for image_info in self.image_data.values():
415
+ for _ in range(image_info.num_repeats):
416
+ self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
417
+
418
+ # print and store bucket information
419
+ if self.enable_bucket:
420
+ self.bucket_info = {"buckets": {}}
421
+ print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
422
+ for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
423
+ count = len(bucket)
424
+ if count > 0:
425
+ self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
426
+ print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
427
+
428
+ img_ar_errors = np.array(img_ar_errors)
429
+ mean_img_ar_error = np.mean(np.abs(img_ar_errors))
430
+ self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
431
+ print(f"mean ar error (without repeats): {mean_img_ar_error}")
432
+
433
+ # build the index used to reference data; this index is what gets shuffled for the dataset
434
+ self.buckets_indices: List[BucketBatchIndex] = []
435
+ for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
436
+ batch_count = int(math.ceil(len(bucket) / self.batch_size))
437
+ for batch_index in range(batch_count):
438
+ self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
439
+
440
+ # ↓ the per-bucket batch sizing below was reverted because it inflated the number of batches per bucket and caused confusion
441
+ #  since training steps are drawn randomly, having the same image twice in a batch is probably not that harmful
442
+ #
443
+ # # with finer-grained buckets, more buckets end up containing only a single kind of image, which means
444
+ # # a batch could be filled entirely with the same image, which is clearly bad
445
+ # # so the batch size was capped at the number of distinct images
446
+ # # even so the same image can appear in one batch, so fewer repeats surely give better shuffle quality?
447
+ # # TO DO a mechanism to reuse regularization images across epochs
448
+ # num_of_image_types = len(set(bucket))
449
+ # bucket_batch_size = min(self.batch_size, num_of_image_types)
450
+ # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
451
+ # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
452
+ # for batch_index in range(batch_count):
453
+ # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
454
+ # ↑ end of the reverted code
455
+
456
+ self.shuffle_buckets()
457
+ self._length = len(self.buckets_indices)
458
+
459
+ def shuffle_buckets(self):
460
+ random.shuffle(self.buckets_indices)
461
+ self.bucket_manager.shuffle()
462
+
463
+ def load_image(self, image_path):
464
+ image = Image.open(image_path)
465
+ if not image.mode == "RGB":
466
+ image = image.convert("RGB")
467
+ img = np.array(image, np.uint8)
468
+ return img
469
+
470
+ def trim_and_resize_if_required(self, image, reso, resized_size):
471
+ image_height, image_width = image.shape[0:2]
472
+
473
+ if image_width != resized_size[0] or image_height != resized_size[1]:
474
+ # resize
475
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # resize with cv2 because we want INTER_AREA
476
+
477
+ image_height, image_width = image.shape[0:2]
478
+ if image_width > reso[0]:
479
+ trim_size = image_width - reso[0]
480
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
+ # print("w", trim_size, p)
482
+ image = image[:, p:p + reso[0]]
483
+ if image_height > reso[1]:
484
+ trim_size = image_height - reso[1]
485
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
+ # print("h", trim_size, p)
487
+ image = image[p:p + reso[1]]
488
+
489
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
+ return image
491
+
492
+ def cache_latents(self, vae):
493
+ # TODO speed this up
494
+ print("caching latents.")
495
+ for info in tqdm(self.image_data.values()):
496
+ if info.latents_npz is not None:
497
+ info.latents = self.load_latents_from_npz(info, False)
498
+ info.latents = torch.FloatTensor(info.latents)
499
+ info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
500
+ if info.latents_flipped is not None:
501
+ info.latents_flipped = torch.FloatTensor(info.latents_flipped)
502
+ continue
503
+
504
+ image = self.load_image(info.absolute_path)
505
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
+
507
+ img_tensor = self.image_transforms(image)
508
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
+ info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
+
511
+ if self.flip_aug:
512
+ image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
+ img_tensor = self.image_transforms(image)
514
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
515
+ info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
516
+
517
+ def get_image_size(self, image_path):
518
+ image = Image.open(image_path)
519
+ return image.size
520
+
521
+ def load_image_with_face_info(self, image_path: str):
522
+ img = self.load_image(image_path)
523
+
524
+ face_cx = face_cy = face_w = face_h = 0
525
+ if self.face_crop_aug_range is not None:
526
+ tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
+ if len(tokens) >= 5:
528
+ face_cx = int(tokens[-4])
529
+ face_cy = int(tokens[-3])
530
+ face_w = int(tokens[-2])
531
+ face_h = int(tokens[-1])
532
+
533
+ return img, face_cx, face_cy, face_w, face_h
534
+
535
+ # crop a nice region around the face
536
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
+ height, width = image.shape[0:2]
538
+ if height == self.height and width == self.width:
539
+ return image
540
+
541
+ # the image is larger than the target size, so resize it
542
+ face_size = max(face_w, face_h)
543
+ min_scale = max(self.height / height, self.width / width) # smallest scale at which the image exactly fits the model input size
544
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # specified minimum face size
545
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # specified maximum face size
546
+ if min_scale >= max_scale: # the specified range has min == max
547
+ scale = min_scale
548
+ else:
549
+ scale = random.uniform(min_scale, max_scale)
550
+
551
+ nh = int(height * scale + .5)
552
+ nw = int(width * scale + .5)
553
+ assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
554
+ image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
555
+ face_cx = int(face_cx * scale + .5)
556
+ face_cy = int(face_cy * scale + .5)
557
+ height, width = nh, nw
558
+
559
+ # crop to e.g. 448*640 with the face at the center
560
+ for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
+ p1 = face_p - target_size // 2 # crop position that puts the face at the center
562
+
563
+ if self.random_crop:
564
+ # shift the crop while still favoring the face near the center, so that some background is included
565
+ range = max(length - face_p, face_p) # the longer of the distances from the image edges to the face center
566
+ p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a reasonable random offset within -range ~ +range
567
+ else:
568
+ # only when a range is specified, add a little randomness (fairly ad hoc)
569
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
+ if face_size > self.size // 10 and face_size >= 40:
571
+ p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
+
573
+ p1 = max(0, min(p1, length - target_size))
574
+
575
+ if axis == 0:
576
+ image = image[p1:p1 + target_size, :]
577
+ else:
578
+ image = image[:, p1:p1 + target_size]
579
+
580
+ return image
581
+
582
+ def load_latents_from_npz(self, image_info: ImageInfo, flipped):
583
+ npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
584
+ if npz_file is None:
585
+ return None
586
+ return np.load(npz_file)['arr_0']
587
+
588
+ def __len__(self):
589
+ return self._length
590
+
591
+ def __getitem__(self, index):
592
+ if index == 0:
593
+ self.shuffle_buckets()
594
+
595
+ bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
+ bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
+ image_index = self.buckets_indices[index].batch_index * bucket_batch_size
598
+
599
+ loss_weights = []
600
+ captions = []
601
+ input_ids_list = []
602
+ latents_list = []
603
+ images = []
604
+
605
+ for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
+ image_info = self.image_data[image_key]
607
+ loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
+
609
+ # process image/latents
610
+ if image_info.latents is not None:
611
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
+ image = None
613
+ elif image_info.latents_npz is not None:
614
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
+ latents = torch.FloatTensor(latents)
616
+ image = None
617
+ else:
618
+ # load the image and crop it if needed
619
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
+ im_h, im_w = img.shape[0:2]
621
+
622
+ if self.enable_bucket:
623
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
+ else:
625
+ if face_cx > 0: # face position info is available
626
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
+ elif im_h > self.height or im_w > self.width:
628
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
+ if im_h > self.height:
630
+ p = random.randint(0, im_h - self.height)
631
+ img = img[p:p + self.height]
632
+ if im_w > self.width:
633
+ p = random.randint(0, im_w - self.width)
634
+ img = img[:, p:p + self.width]
635
+
636
+ im_h, im_w = img.shape[0:2]
637
+ assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
+
639
+ # augmentation
640
+ if self.aug is not None:
641
+ img = self.aug(image=img)['image']
642
+
643
+ latents = None
644
+ image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0~1.0
645
+
646
+ images.append(image)
647
+ latents_list.append(latents)
648
+
649
+ caption = self.process_caption(image_info.caption)
650
+ captions.append(caption)
651
+ if not self.token_padding_disabled: # this option might be omitted in future
652
+ input_ids_list.append(self.get_input_ids(caption))
653
+
654
+ example = {}
655
+ example['loss_weights'] = torch.FloatTensor(loss_weights)
656
+
657
+ if self.token_padding_disabled:
658
+ # padding=True means pad in the batch
659
+ example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
660
+ else:
661
+ # batch processing seems to be good
662
+ example['input_ids'] = torch.stack(input_ids_list)
663
+
664
+ if images[0] is not None:
665
+ images = torch.stack(images)
666
+ images = images.to(memory_format=torch.contiguous_format).float()
667
+ else:
668
+ images = None
669
+ example['images'] = images
670
+
671
+ example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
672
+
673
+ if self.debug_dataset:
674
+ example['image_keys'] = bucket[image_index:image_index + self.batch_size]
675
+ example['captions'] = captions
676
+ return example
677
+
678
+
679
+ class DreamBoothDataset(BaseDataset):
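+ # DreamBooth-style dataset: reads "<repeats>_<caption>" folders of training/regularization images and repeats reg images to match the training image count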
680
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
+
684
+ assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
+
686
+ self.batch_size = batch_size
687
+ self.size = min(self.width, self.height) # the shorter side
688
+ self.prior_loss_weight = prior_loss_weight
689
+ self.latents_cache = None
690
+
691
+ self.enable_bucket = enable_bucket
692
+ if self.enable_bucket:
693
+ assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
694
+ assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
695
+ self.min_bucket_reso = min_bucket_reso
696
+ self.max_bucket_reso = max_bucket_reso
697
+ self.bucket_reso_steps = bucket_reso_steps
698
+ self.bucket_no_upscale = bucket_no_upscale
699
+ else:
700
+ self.min_bucket_reso = None
701
+ self.max_bucket_reso = None
702
+ self.bucket_reso_steps = None # this value is not used
703
+ self.bucket_no_upscale = False
704
+
705
+ def read_caption(img_path):
706
+ # build candidate caption file names
707
+ base_name = os.path.splitext(img_path)[0]
708
+ base_name_face_det = base_name
709
+ tokens = base_name.split("_")
710
+ if len(tokens) >= 5:
711
+ base_name_face_det = "_".join(tokens[:-4])
712
+ cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
713
+
714
+ caption = None
715
+ for cap_path in cap_paths:
716
+ if os.path.isfile(cap_path):
717
+ with open(cap_path, "rt", encoding='utf-8') as f:
718
+ try:
719
+ lines = f.readlines()
720
+ except UnicodeDecodeError as e:
721
+ print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
722
+ raise e
723
+ assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
724
+ caption = lines[0].strip()
725
+ break
726
+ return caption
727
+
728
+ def load_dreambooth_dir(dir):
729
+ if not os.path.isdir(dir):
730
+ # print(f"ignore file: {dir}")
731
+ return 0, [], []
732
+
733
+ tokens = os.path.basename(dir).split('_')
734
+ try:
735
+ n_repeats = int(tokens[0])
736
+ except ValueError as e:
737
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
+ return 0, [], []
739
+
740
+ caption_by_folder = '_'.join(tokens[1:])
741
+ img_paths = glob_images(dir, "*")
742
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
+
744
+ # read a caption for each image file and prefer it over the folder caption if present
745
+ captions = []
746
+ for img_path in img_paths:
747
+ cap_for_img = read_caption(img_path)
748
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
749
+
750
+ self.set_tag_frequency(os.path.basename(dir), captions) # record tag frequencies
751
+
752
+ return n_repeats, img_paths, captions
753
+
754
+ print("prepare train images.")
755
+ train_dirs = os.listdir(train_data_dir)
756
+ num_train_images = 0
757
+ for dir in train_dirs:
758
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
+ num_train_images += n_repeats * len(img_paths)
760
+
761
+ for img_path, caption in zip(img_paths, captions):
762
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
+ self.register_image(info)
764
+
765
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
766
+
767
+ print(f"{num_train_images} train images with repeating.")
768
+ self.num_train_images = num_train_images
769
+
770
+ # count regularization images and repeat them so they match the number of training images
771
+ num_reg_images = 0
772
+ if reg_data_dir:
773
+ print("prepare reg images.")
774
+ reg_infos: List[ImageInfo] = []
775
+
776
+ reg_dirs = os.listdir(reg_data_dir)
777
+ for dir in reg_dirs:
778
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
+ num_reg_images += n_repeats * len(img_paths)
780
+
781
+ for img_path, caption in zip(img_paths, captions):
782
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
+ reg_infos.append(info)
784
+
785
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
+
787
+ print(f"{num_reg_images} reg images.")
788
+ if num_train_images < num_reg_images:
789
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
+
791
+ if num_reg_images == 0:
792
+ print("no regularization images / 正則化画像が見つかりませんでした")
793
+ else:
794
+ # compute num_repeats: the counts are small, so a simple loop is fine
795
+ n = 0
796
+ first_loop = True
797
+ while n < num_train_images:
798
+ for info in reg_infos:
799
+ if first_loop:
800
+ self.register_image(info)
801
+ n += info.num_repeats
802
+ else:
803
+ info.num_repeats += 1
804
+ n += 1
805
+ if n >= num_train_images:
806
+ break
807
+ first_loop = False
808
+
809
+ self.num_reg_images = num_reg_images
810
+
811
+
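The while-loop above distributes extra repeats over the regularization images until their total matches the number of training images (with repeats). A minimal sketch of that distribution logic, using made-up counts instead of the dataset classes:

# Illustrative only: 40 training images vs. 3 regularization images with n_repeats=1 each.
num_train_images = 40
reg_repeats = [1, 1, 1]            # num_repeats per reg ImageInfo after the first loop

n = sum(reg_repeats)               # the first pass registers each reg image once per repeat
while n < num_train_images:
    for i in range(len(reg_repeats)):
        reg_repeats[i] += 1
        n += 1
        if n >= num_train_images:
            break

print(reg_repeats)                 # [14, 13, 13] -> totals 40, matching the training images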
812
+ class FineTuningDataset(BaseDataset):
813
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
+
817
+ # load the metadata
818
+ if os.path.exists(json_file_name):
819
+ print(f"loading existing metadata: {json_file_name}")
820
+ with open(json_file_name, "rt", encoding='utf-8') as f:
821
+ metadata = json.load(f)
822
+ else:
823
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
+
825
+ self.metadata = metadata
826
+ self.train_data_dir = train_data_dir
827
+ self.batch_size = batch_size
828
+
829
+ tags_list = []
830
+ for image_key, img_md in metadata.items():
831
+ # build the absolute path for this image key
832
+ if os.path.exists(image_key):
833
+ abs_path = image_key
834
+ else:
835
+ # somewhat crude, but no better way comes to mind
836
+ abs_path = glob_images(train_data_dir, image_key)
837
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
+ abs_path = abs_path[0]
839
+
840
+ caption = img_md.get('caption')
841
+ tags = img_md.get('tags')
842
+ if caption is None:
843
+ caption = tags
844
+ elif tags is not None and len(tags) > 0:
845
+ caption = caption + ', ' + tags
846
+ tags_list.append(tags)
847
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
+
849
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
+ image_info.image_size = img_md.get('train_resolution')
851
+
852
+ if not self.color_aug and not self.random_crop:
853
+ # if npz exists, use them
854
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
+
856
+ self.register_image(image_info)
857
+ self.num_train_images = len(metadata) * dataset_repeats
858
+ self.num_reg_images = 0
859
+
860
+ # TODO do not record tag freq when no tag
861
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
863
+
864
+ # check existence of all npz files
865
+ use_npz_latents = not (self.color_aug or self.random_crop)
866
+ if use_npz_latents:
867
+ npz_any = False
868
+ npz_all = True
869
+ for image_info in self.image_data.values():
870
+ has_npz = image_info.latents_npz is not None
871
+ npz_any = npz_any or has_npz
872
+
873
+ if self.flip_aug:
874
+ has_npz = has_npz and image_info.latents_npz_flipped is not None
875
+ npz_all = npz_all and has_npz
876
+
877
+ if npz_any and not npz_all:
878
+ break
879
+
880
+ if not npz_any:
881
+ use_npz_latents = False
882
+ print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
883
+ elif not npz_all:
884
+ use_npz_latents = False
885
+ print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
886
+ if self.flip_aug:
887
+ print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
+ # else:
889
+ # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
890
+
891
+ # check min/max bucket size
892
+ sizes = set()
893
+ resos = set()
894
+ for image_info in self.image_data.values():
895
+ if image_info.image_size is None:
896
+ sizes = None # not calculated
897
+ break
898
+ sizes.add(image_info.image_size[0])
899
+ sizes.add(image_info.image_size[1])
900
+ resos.add(tuple(image_info.image_size))
901
+
902
+ if sizes is None:
903
+ if use_npz_latents:
904
+ use_npz_latents = False
905
+ print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
906
+
907
+ assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
908
+
909
+ self.enable_bucket = enable_bucket
910
+ if self.enable_bucket:
911
+ self.min_bucket_reso = min_bucket_reso
912
+ self.max_bucket_reso = max_bucket_reso
913
+ self.bucket_reso_steps = bucket_reso_steps
914
+ self.bucket_no_upscale = bucket_no_upscale
915
+ else:
916
+ if not enable_bucket:
917
+ print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
918
+ print("using bucket info in metadata / メタデータ内のbucket情報を使います")
919
+ self.enable_bucket = True
920
+
921
+ assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
922
+
923
+ # initialize bucket info here; it is not rebuilt in make_buckets
924
+ self.bucket_manager = BucketManager(False, None, None, None, None)
925
+ self.bucket_manager.set_predefined_resos(resos)
926
+
927
+ # clear out the npz info
928
+ if not use_npz_latents:
929
+ for image_info in self.image_data.values():
930
+ image_info.latents_npz = image_info.latents_npz_flipped = None
931
+
932
+ def image_key_to_npz_file(self, image_key):
933
+ base_name = os.path.splitext(image_key)[0]
934
+ npz_file_norm = base_name + '.npz'
935
+
936
+ if os.path.exists(npz_file_norm):
937
+ # image_key is full path
938
+ npz_file_flip = base_name + '_flip.npz'
939
+ if not os.path.exists(npz_file_flip):
940
+ npz_file_flip = None
941
+ return npz_file_norm, npz_file_flip
942
+
943
+ # image_key is relative path
944
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
+
947
+ if not os.path.exists(npz_file_norm):
948
+ npz_file_norm = None
949
+ npz_file_flip = None
950
+ elif not os.path.exists(npz_file_flip):
951
+ npz_file_flip = None
952
+
953
+ return npz_file_norm, npz_file_flip
954
+
955
+
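image_key_to_npz_file above encodes the naming convention for cached latents: <image>.npz next to the image when the metadata key is a full path, otherwise <key>.npz under train_data_dir, plus an optional <...>_flip.npz for the horizontally flipped latents. A tiny sketch of the expected layout (paths are hypothetical):

import os

train_data_dir = "train_images"       # hypothetical dataset directory
image_key = "subdir/0001"             # relative key as it appears in the metadata json

npz_norm = os.path.join(train_data_dir, image_key + ".npz")       # cached latents
npz_flip = os.path.join(train_data_dir, image_key + "_flip.npz")  # latents of the flipped image
print(npz_norm, npz_flip)             # only used if these files actually exist on disk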
956
+ def debug_dataset(train_dataset, show_input_ids=False):
957
+ print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
+ print("Escape for exit. / Escキーで中断、終了します")
959
+
960
+ train_dataset.set_current_epoch(1)
961
+ k = 0
962
+ for i, example in enumerate(train_dataset):
963
+ if example['latents'] is not None:
964
+ print(f"sample has latents from npz file: {example['latents'].size()}")
965
+ for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
966
+ print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
967
+ if show_input_ids:
968
+ print(f"input ids: {iid}")
969
+ if example['images'] is not None:
970
+ im = example['images'][j]
971
+ print(f"image size: {im.size()}")
972
+ im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
973
+ im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
974
+ im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
975
+ if os.name == 'nt': # only windows
976
+ cv2.imshow("img", im)
977
+ k = cv2.waitKey()
978
+ cv2.destroyAllWindows()
979
+ if k == 27:
980
+ break
981
+ if k == 27 or (example['images'] is None and i >= 8):
982
+ break
983
+
984
+
985
+ def glob_images(directory, base="*"):
986
+ img_paths = []
987
+ for ext in IMAGE_EXTENSIONS:
988
+ if base == '*':
989
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
990
+ else:
991
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
992
+ # img_paths = list(set(img_paths)) # 重複を排除
993
+ # img_paths.sort()
994
+ return img_paths
995
+
996
+
997
+ def glob_images_pathlib(dir_path, recursive):
998
+ image_paths = []
999
+ if recursive:
1000
+ for ext in IMAGE_EXTENSIONS:
1001
+ image_paths += list(dir_path.rglob('*' + ext))
1002
+ else:
1003
+ for ext in IMAGE_EXTENSIONS:
1004
+ image_paths += list(dir_path.glob('*' + ext))
1005
+ # image_paths = list(set(image_paths)) # 重複を排除
1006
+ # image_paths.sort()
1007
+ return image_paths
1008
+
1009
+ # endregion
1010
+
1011
+
1012
+ # region module replacement
1013
+ """
1014
+ Module replacements for speed-up
1015
+ """
1016
+
1017
+ # CrossAttention using FlashAttention
1018
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
1019
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
1020
+
1021
+ # constants
1022
+
1023
+ EPSILON = 1e-6
1024
+
1025
+ # helper functions
1026
+
1027
+
1028
+ def exists(val):
1029
+ return val is not None
1030
+
1031
+
1032
+ def default(val, d):
1033
+ return val if exists(val) else d
1034
+
1035
+
1036
+ def model_hash(filename):
1037
+ """Old model hash used by stable-diffusion-webui"""
1038
+ try:
1039
+ with open(filename, "rb") as file:
1040
+ m = hashlib.sha256()
1041
+
1042
+ file.seek(0x100000)
1043
+ m.update(file.read(0x10000))
1044
+ return m.hexdigest()[0:8]
1045
+ except FileNotFoundError:
1046
+ return 'NOFILE'
1047
+
1048
+
1049
+ def calculate_sha256(filename):
1050
+ """New model hash used by stable-diffusion-webui"""
1051
+ hash_sha256 = hashlib.sha256()
1052
+ blksize = 1024 * 1024
1053
+
1054
+ with open(filename, "rb") as f:
1055
+ for chunk in iter(lambda: f.read(blksize), b""):
1056
+ hash_sha256.update(chunk)
1057
+
1058
+ return hash_sha256.hexdigest()
1059
+
1060
+
1061
+ def precalculate_safetensors_hashes(tensors, metadata):
1062
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
1063
+ save time on indexing the model later."""
1064
+
1065
+ # Because writing user metadata to the file can change the result of
1066
+ # sd_models.model_hash(), only retain the training metadata for purposes of
1067
+ # calculating the hash, as they are meant to be immutable
1068
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
1069
+
1070
+ bytes = safetensors.torch.save(tensors, metadata)
1071
+ b = BytesIO(bytes)
1072
+
1073
+ model_hash = addnet_hash_safetensors(b)
1074
+ legacy_hash = addnet_hash_legacy(b)
1075
+ return model_hash, legacy_hash
1076
+
1077
+
1078
+ def addnet_hash_legacy(b):
1079
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
1080
+ m = hashlib.sha256()
1081
+
1082
+ b.seek(0x100000)
1083
+ m.update(b.read(0x10000))
1084
+ return m.hexdigest()[0:8]
1085
+
1086
+
1087
+ def addnet_hash_safetensors(b):
1088
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
1089
+ hash_sha256 = hashlib.sha256()
1090
+ blksize = 1024 * 1024
1091
+
1092
+ b.seek(0)
1093
+ header = b.read(8)
1094
+ n = int.from_bytes(header, "little")
1095
+
1096
+ offset = n + 8
1097
+ b.seek(offset)
1098
+ for chunk in iter(lambda: b.read(blksize), b""):
1099
+ hash_sha256.update(chunk)
1100
+
1101
+ return hash_sha256.hexdigest()
1102
+
1103
+
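addnet_hash_safetensors above skips the safetensors header, which begins with an 8-byte little-endian length followed by the JSON header, so the hash covers only the tensor data and stays stable when user metadata changes. A minimal usage sketch with a made-up tensor dict (the key names are hypothetical):

import torch

tensors = {"lora_up.weight": torch.zeros(4, 4), "lora_down.weight": torch.zeros(4, 4)}
metadata = {"ss_network_dim": "4", "user_note": "not part of the hash"}

new_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
print(new_hash, legacy_hash)  # only the "ss_" training metadata is retained for hashing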
1104
+ def get_git_revision_hash() -> str:
1105
+ try:
1106
+ return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
1107
+ except:
1108
+ return "(unknown)"
1109
+
1110
+
1111
+ # flash attention forwards and backwards
1112
+
1113
+ # https://arxiv.org/abs/2205.14135
1114
+
1115
+
1116
+ class FlashAttentionFunction(torch.autograd.function.Function):
1117
+ @ staticmethod
1118
+ @ torch.no_grad()
1119
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
1120
+ """ Algorithm 2 in the paper """
1121
+
1122
+ device = q.device
1123
+ dtype = q.dtype
1124
+ max_neg_value = -torch.finfo(q.dtype).max
1125
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1126
+
1127
+ o = torch.zeros_like(q)
1128
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
1129
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
1130
+
1131
+ scale = (q.shape[-1] ** -0.5)
1132
+
1133
+ if not exists(mask):
1134
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
1135
+ else:
1136
+ mask = rearrange(mask, 'b n -> b 1 1 n')
1137
+ mask = mask.split(q_bucket_size, dim=-1)
1138
+
1139
+ row_splits = zip(
1140
+ q.split(q_bucket_size, dim=-2),
1141
+ o.split(q_bucket_size, dim=-2),
1142
+ mask,
1143
+ all_row_sums.split(q_bucket_size, dim=-2),
1144
+ all_row_maxes.split(q_bucket_size, dim=-2),
1145
+ )
1146
+
1147
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
1148
+ q_start_index = ind * q_bucket_size - qk_len_diff
1149
+
1150
+ col_splits = zip(
1151
+ k.split(k_bucket_size, dim=-2),
1152
+ v.split(k_bucket_size, dim=-2),
1153
+ )
1154
+
1155
+ for k_ind, (kc, vc) in enumerate(col_splits):
1156
+ k_start_index = k_ind * k_bucket_size
1157
+
1158
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1159
+
1160
+ if exists(row_mask):
1161
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
1162
+
1163
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1164
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1165
+ device=device).triu(q_start_index - k_start_index + 1)
1166
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1167
+
1168
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
1169
+ attn_weights -= block_row_maxes
1170
+ exp_weights = torch.exp(attn_weights)
1171
+
1172
+ if exists(row_mask):
1173
+ exp_weights.masked_fill_(~row_mask, 0.)
1174
+
1175
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
1176
+
1177
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
1178
+
1179
+ exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
1180
+
1181
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
1182
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
1183
+
1184
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
1185
+
1186
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
1187
+
1188
+ row_maxes.copy_(new_row_maxes)
1189
+ row_sums.copy_(new_row_sums)
1190
+
1191
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
1192
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
1193
+
1194
+ return o
1195
+
1196
+ @ staticmethod
1197
+ @ torch.no_grad()
1198
+ def backward(ctx, do):
1199
+ """ Algorithm 4 in the paper """
1200
+
1201
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
1202
+ q, k, v, o, l, m = ctx.saved_tensors
1203
+
1204
+ device = q.device
1205
+
1206
+ max_neg_value = -torch.finfo(q.dtype).max
1207
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1208
+
1209
+ dq = torch.zeros_like(q)
1210
+ dk = torch.zeros_like(k)
1211
+ dv = torch.zeros_like(v)
1212
+
1213
+ row_splits = zip(
1214
+ q.split(q_bucket_size, dim=-2),
1215
+ o.split(q_bucket_size, dim=-2),
1216
+ do.split(q_bucket_size, dim=-2),
1217
+ mask,
1218
+ l.split(q_bucket_size, dim=-2),
1219
+ m.split(q_bucket_size, dim=-2),
1220
+ dq.split(q_bucket_size, dim=-2)
1221
+ )
1222
+
1223
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
1224
+ q_start_index = ind * q_bucket_size - qk_len_diff
1225
+
1226
+ col_splits = zip(
1227
+ k.split(k_bucket_size, dim=-2),
1228
+ v.split(k_bucket_size, dim=-2),
1229
+ dk.split(k_bucket_size, dim=-2),
1230
+ dv.split(k_bucket_size, dim=-2),
1231
+ )
1232
+
1233
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
1234
+ k_start_index = k_ind * k_bucket_size
1235
+
1236
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1237
+
1238
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1239
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1240
+ device=device).triu(q_start_index - k_start_index + 1)
1241
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1242
+
1243
+ exp_attn_weights = torch.exp(attn_weights - mc)
1244
+
1245
+ if exists(row_mask):
1246
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
1247
+
1248
+ p = exp_attn_weights / lc
1249
+
1250
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
1251
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
1252
+
1253
+ D = (doc * oc).sum(dim=-1, keepdims=True)
1254
+ ds = p * scale * (dp - D)
1255
+
1256
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
1257
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
1258
+
1259
+ dqc.add_(dq_chunk)
1260
+ dkc.add_(dk_chunk)
1261
+ dvc.add_(dv_chunk)
1262
+
1263
+ return dq, dk, dv, None, None, None, None
1264
+
1265
+
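The forward pass above keeps running row maxima and row sums so the softmax can be accumulated block by block without materializing the full attention matrix. A small self-check of that idea, comparing the chunked function against plain softmax attention on random tensors:

import torch

torch.manual_seed(0)
q = torch.randn(2, 1, 8, 16)   # (batch, heads, tokens, dim)
k = torch.randn(2, 1, 8, 16)
v = torch.randn(2, 1, 8, 16)

scale = q.shape[-1] ** -0.5
ref = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v   # ordinary attention

out = FlashAttentionFunction.apply(q, k, v, None, False, 4, 4)     # bucket sizes of 4
print(torch.allclose(ref, out, atol=1e-5))                         # expected: True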
1266
+ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
1267
+ if mem_eff_attn:
1268
+ replace_unet_cross_attn_to_memory_efficient()
1269
+ elif xformers:
1270
+ replace_unet_cross_attn_to_xformers()
1271
+
1272
+
1273
+ def replace_unet_cross_attn_to_memory_efficient():
1274
+ print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
1275
+ flash_func = FlashAttentionFunction
1276
+
1277
+ def forward_flash_attn(self, x, context=None, mask=None):
1278
+ q_bucket_size = 512
1279
+ k_bucket_size = 1024
1280
+
1281
+ h = self.heads
1282
+ q = self.to_q(x)
1283
+
1284
+ context = context if context is not None else x
1285
+ context = context.to(x.dtype)
1286
+
1287
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1288
+ context_k, context_v = self.hypernetwork.forward(x, context)
1289
+ context_k = context_k.to(x.dtype)
1290
+ context_v = context_v.to(x.dtype)
1291
+ else:
1292
+ context_k = context
1293
+ context_v = context
1294
+
1295
+ k = self.to_k(context_k)
1296
+ v = self.to_v(context_v)
1297
+ del context, x
1298
+
1299
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
1300
+
1301
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
1302
+
1303
+ out = rearrange(out, 'b h n d -> b n (h d)')
1304
+
1305
+ # diffusers 0.7.0+ split to_out into a ModuleList (annoying change)
1306
+ out = self.to_out[0](out)
1307
+ out = self.to_out[1](out)
1308
+ return out
1309
+
1310
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
1311
+
1312
+
1313
+ def replace_unet_cross_attn_to_xformers():
1314
+ print("Replace CrossAttention.forward to use xformers")
1315
+ try:
1316
+ import xformers.ops
1317
+ except ImportError:
1318
+ raise ImportError("No xformers / xformersがインストールされていないようです")
1319
+
1320
+ def forward_xformers(self, x, context=None, mask=None):
1321
+ h = self.heads
1322
+ q_in = self.to_q(x)
1323
+
1324
+ context = default(context, x)
1325
+ context = context.to(x.dtype)
1326
+
1327
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1328
+ context_k, context_v = self.hypernetwork.forward(x, context)
1329
+ context_k = context_k.to(x.dtype)
1330
+ context_v = context_v.to(x.dtype)
1331
+ else:
1332
+ context_k = context
1333
+ context_v = context
1334
+
1335
+ k_in = self.to_k(context_k)
1336
+ v_in = self.to_v(context_v)
1337
+
1338
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
1339
+ del q_in, k_in, v_in
1340
+
1341
+ q = q.contiguous()
1342
+ k = k.contiguous()
1343
+ v = v.contiguous()
1344
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # xformers picks the best kernel automatically
1345
+
1346
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
1347
+
1348
+ # diffusers 0.7.0~
1349
+ out = self.to_out[0](out)
1350
+ out = self.to_out[1](out)
1351
+ return out
1352
+
1353
+ diffusers.models.attention.CrossAttention.forward = forward_xformers
1354
+ # endregion
1355
+
1356
+
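Both replacements above monkey-patch diffusers.models.attention.CrossAttention.forward globally, so a single call before the model is built is enough; the unet argument itself is not modified. A minimal sketch of how the training scripts invoke it (flags mirror --mem_eff_attn / --xformers; in the scripts the real unet comes from load_target_model below):

replace_unet_modules(None, mem_eff_attn=True, xformers=False)  # patch CrossAttention to the FlashAttention path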
1357
+ # region arguments
1358
+
1359
+ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1360
+ # for pretrained models
1361
+ parser.add_argument("--v2", action='store_true',
1362
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
1363
+ parser.add_argument("--v_parameterization", action='store_true',
1364
+ help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
+ help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
+
1368
+
1369
+ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
1370
+ parser.add_argument("--output_dir", type=str, default=None,
1371
+ help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
1372
+ parser.add_argument("--output_name", type=str, default=None,
1373
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
1374
+ parser.add_argument("--save_precision", type=str, default=None,
1375
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
1376
+ parser.add_argument("--save_every_n_epochs", type=int, default=None,
1377
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
1378
+ parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
1379
+ help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
1380
+ parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
1381
+ parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
1382
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
1383
+ parser.add_argument("--save_state", action="store_true",
1384
+ help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
1385
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1386
+
1387
+ parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
+ parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
+ help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
+ parser.add_argument("--use_8bit_adam", action="store_true",
1391
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1393
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
+ parser.add_argument("--mem_eff_attn", action="store_true",
1395
+ help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
+ parser.add_argument("--xformers", action="store_true",
1397
+ help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
1398
+ parser.add_argument("--vae", type=str, default=None,
1399
+ help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
+
1401
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
+ parser.add_argument("--max_train_epochs", type=int, default=None,
1404
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
1405
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
1406
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
1407
+ parser.add_argument("--persistent_data_loader_workers", action="store_true",
1408
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
1409
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1410
+ parser.add_argument("--gradient_checkpointing", action="store_true",
1411
+ help="enable gradient checkpointing / grandient checkpointingを有効にする")
1412
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
1413
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
1414
+ parser.add_argument("--mixed_precision", type=str, default="no",
1415
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
1416
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1417
+ parser.add_argument("--clip_skip", type=int, default=None,
1418
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
1419
+ parser.add_argument("--logging_dir", type=str, default=None,
1420
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
+ parser.add_argument("--noise_offset", type=float, default=None,
1427
+ help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
+ parser.add_argument("--lowram", action="store_true",
1429
+ help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
+
1431
+ if support_dreambooth:
1432
+ # DreamBooth training
1433
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0,
1434
+ help="loss weight for regularization images / 正則化画像のlossの重み")
1435
+
1436
+
1437
+ def verify_training_args(args: argparse.Namespace):
1438
+ if args.v_parameterization and not args.v2:
1439
+ print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
1440
+ if args.v2 and args.clip_skip is not None:
1441
+ print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1442
+
1443
+
1444
+ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
1445
+ # dataset common
1446
+ parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
1447
+ parser.add_argument("--shuffle_caption", action="store_true",
1448
+ help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
1449
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
+ parser.add_argument("--caption_extention", type=str, default=None,
1451
+ help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
+ parser.add_argument("--keep_tokens", type=int, default=None,
1453
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
+ parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
+ parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
+ parser.add_argument("--face_crop_aug_range", type=str, default=None,
1457
+ help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
1458
+ parser.add_argument("--random_crop", action="store_true",
1459
+ help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
1460
+ parser.add_argument("--debug_dataset", action="store_true",
1461
+ help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
1462
+ parser.add_argument("--resolution", type=str, default=None,
1463
+ help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
1464
+ parser.add_argument("--cache_latents", action="store_true",
1465
+ help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
1466
+ parser.add_argument("--enable_bucket", action="store_true",
1467
+ help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
1468
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
1469
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
1470
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
1471
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
1472
+ parser.add_argument("--bucket_no_upscale", action="store_true",
1473
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
1474
+
1475
+ if support_caption_dropout:
1476
+ # Textual Inversion does not support caption dropout
1477
+ # the "caption_" prefix avoids confusion with the usual tensor Dropout; every_n_epochs defaults to None to match the other options
1478
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
+ help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
+ help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
+ help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
+
1485
+ if support_dreambooth:
1486
+ # DreamBooth dataset
1487
+ parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
1488
+
1489
+ if support_caption:
1490
+ # caption dataset
1491
+ parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
1492
+ parser.add_argument("--dataset_repeats", type=int, default=1,
1493
+ help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
1494
+
1495
+
1496
+ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1497
+ parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
1498
+ help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
1499
+ parser.add_argument("--use_safetensors", action='store_true',
1500
+ help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
1501
+
1502
+ # endregion
1503
+
1504
+ # region utils
1505
+
1506
+
1507
+ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
+ # backward compatibility
1509
+ if args.caption_extention is not None:
1510
+ args.caption_extension = args.caption_extention
1511
+ args.caption_extention = None
1512
+
1513
+ if args.cache_latents:
1514
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
+
1517
+ # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
+ if args.resolution is not None:
1519
+ args.resolution = tuple([int(r) for r in args.resolution.split(',')])
1520
+ if len(args.resolution) == 1:
1521
+ args.resolution = (args.resolution[0], args.resolution[0])
1522
+ assert len(args.resolution) == 2, \
1523
+ f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
1524
+
1525
+ if args.face_crop_aug_range is not None:
1526
+ args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
1527
+ assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
1528
+ f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
1529
+ else:
1530
+ args.face_crop_aug_range = None
1531
+
1532
+ if support_metadata:
1533
+ if args.in_json is not None and (args.color_aug or args.random_crop):
1534
+ print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
1535
+
1536
+
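prepare_dataset_args normalizes the string-valued CLI options into tuples. A quick illustration of the conversions it performs, with made-up argument values:

import argparse

args = argparse.Namespace(caption_extention=None, caption_extension=".caption",
                          cache_latents=False, color_aug=False, random_crop=False,
                          resolution="512,768", face_crop_aug_range="2.0,4.0", in_json=None)
prepare_dataset_args(args, support_metadata=True)
print(args.resolution)           # (512, 768); a single "512" would become (512, 512)
print(args.face_crop_aug_range)  # (2.0, 4.0)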
1537
+ def load_tokenizer(args: argparse.Namespace):
1538
+ print("prepare tokenizer")
1539
+ if args.v2:
1540
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
+ else:
1542
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
+ if args.max_token_length is not None:
1544
+ print(f"update token length: {args.max_token_length}")
1545
+ return tokenizer
1546
+
1547
+
1548
+ def prepare_accelerator(args: argparse.Namespace):
1549
+ if args.logging_dir is None:
1550
+ log_with = None
1551
+ logging_dir = None
1552
+ else:
1553
+ log_with = "tensorboard"
1554
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
1555
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
1556
+
1557
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
1558
+ log_with=log_with, logging_dir=logging_dir)
1559
+
1560
+ # work around an accelerate API compatibility issue (unwrap_model gained a second argument in 0.15.0)
1561
+ accelerator_0_15 = True
1562
+ try:
1563
+ accelerator.unwrap_model("dummy", True)
1564
+ print("Using accelerator 0.15.0 or above.")
1565
+ except TypeError:
1566
+ accelerator_0_15 = False
1567
+
1568
+ def unwrap_model(model):
1569
+ if accelerator_0_15:
1570
+ return accelerator.unwrap_model(model, True)
1571
+ return accelerator.unwrap_model(model)
1572
+
1573
+ return accelerator, unwrap_model
1574
+
1575
+
1576
+ def prepare_dtype(args: argparse.Namespace):
1577
+ weight_dtype = torch.float32
1578
+ if args.mixed_precision == "fp16":
1579
+ weight_dtype = torch.float16
1580
+ elif args.mixed_precision == "bf16":
1581
+ weight_dtype = torch.bfloat16
1582
+
1583
+ save_dtype = None
1584
+ if args.save_precision == "fp16":
1585
+ save_dtype = torch.float16
1586
+ elif args.save_precision == "bf16":
1587
+ save_dtype = torch.bfloat16
1588
+ elif args.save_precision == "float":
1589
+ save_dtype = torch.float32
1590
+
1591
+ return weight_dtype, save_dtype
1592
+
1593
+
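A sketch of how the setup helpers above are typically chained at the start of the training scripts; the values are placeholders, and this assumes the accelerate version this repository targets (the logging_dir argument of Accelerator may have been renamed in newer releases):

import argparse

args = argparse.Namespace(logging_dir=None, log_prefix=None, gradient_accumulation_steps=1,
                          mixed_precision="no", save_precision="float")
accelerator, unwrap_model = prepare_accelerator(args)
weight_dtype, save_dtype = prepare_dtype(args)
print(weight_dtype, save_dtype)  # torch.float32 torch.float32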
1594
+ def load_target_model(args: argparse.Namespace, weight_dtype):
1595
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
1596
+ if load_stable_diffusion_format:
1597
+ print("load StableDiffusion checkpoint")
1598
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
+ else:
1600
+ print("load Diffusers pretrained models")
1601
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1602
+ text_encoder = pipe.text_encoder
1603
+ vae = pipe.vae
1604
+ unet = pipe.unet
1605
+ del pipe
1606
+
1607
+ # load a replacement VAE if one was specified
1608
+ if args.vae is not None:
1609
+ vae = model_util.load_vae(args.vae, weight_dtype)
1610
+ print("additional VAE loaded")
1611
+
1612
+ return text_encoder, vae, unet, load_stable_diffusion_format
1613
+
1614
+
1615
+ def patch_accelerator_for_fp16_training(accelerator):
1616
+ org_unscale_grads = accelerator.scaler._unscale_grads_
1617
+
1618
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1619
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
1620
+
1621
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
1622
+
1623
+
1624
+ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
1625
+ # with no_token_padding, the length is not the max length, so return the result immediately
1626
+ if input_ids.size()[-1] != tokenizer.model_max_length:
1627
+ return text_encoder(input_ids)[0]
1628
+
1629
+ b_size = input_ids.size()[0]
1630
+ input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
1631
+
1632
+ if args.clip_skip is None:
1633
+ encoder_hidden_states = text_encoder(input_ids)[0]
1634
+ else:
1635
+ enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
1636
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
1637
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
1638
+
1639
+ # bs*3, 77, 768 or 1024
1640
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
1641
+
1642
+ if args.max_token_length is not None:
1643
+ if args.v2:
1644
+ # v2: merge the three <BOS>...<EOS> <PAD> ... chunks back into a single <BOS>...<EOS> <PAD> ... sequence (honestly not sure this implementation is right)
1645
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1646
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1647
+ chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # from after <BOS> to just before the last token
1648
+ if i > 0:
1649
+ for j in range(len(chunk)):
1650
+ if input_ids[j, 1] == tokenizer.eos_token_id: # empty caption, i.e. the <BOS> <EOS> <PAD> ... pattern (compare against the id, not the token string)
1651
+ chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
1652
+ states_list.append(chunk) # from after <BOS> to before <EOS>
1653
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
1654
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1655
+ else:
1656
+ # v1: merge the three <BOS>...<EOS> chunks back into a single <BOS>...<EOS> sequence
1657
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1658
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1659
+ states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # from after <BOS> to before <EOS>
1660
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
1661
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1662
+
1663
+ if weight_dtype is not None:
1664
+ # this is required for additional network training
1665
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
1666
+
1667
+ return encoder_hidden_states
1668
+
1669
+
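For --max_token_length 225 the ids arrive as three 77-token CLIP windows; the reassembly above keeps one <BOS>, the 75 content positions of each window, and a single trailing <EOS>/<PAD>. A quick check of that arithmetic (shapes only, no model involved):

model_max_length = 77    # CLIP window size
max_token_length = 225   # three windows of 75 content tokens

starts = list(range(1, max_token_length, model_max_length))   # [1, 78, 155]
chunk_len = model_max_length - 2                              # 75 content tokens per window
final_len = 1 + len(starts) * chunk_len + 1                   # <BOS> + content + <EOS>/<PAD>
print(starts, final_len)                                      # [1, 78, 155] 227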
1670
+ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
1671
+ model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
1672
+ ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
1673
+ return model_name, ckpt_name
1674
+
1675
+
1676
+ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
1677
+ saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
1678
+ if saving:
1679
+ os.makedirs(args.output_dir, exist_ok=True)
1680
+ save_func()
1681
+
1682
+ if args.save_last_n_epochs is not None:
1683
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
1684
+ remove_old_func(remove_epoch_no)
1685
+ return saving
1686
+
1687
+
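save_on_epoch_end above saves only on multiples of --save_every_n_epochs (the final epoch is handled by the end-of-training save) and asks the caller to delete the checkpoint from save_every_n_epochs * save_last_n_epochs epochs earlier. A small worked example of that bookkeeping with stub callbacks (it creates args.output_dir as a side effect):

import argparse

args = argparse.Namespace(output_dir="output", save_every_n_epochs=2, save_last_n_epochs=3)
saved, removed = [], []
for epoch_no in range(1, 11):  # 10 epochs in total
    save_on_epoch_end(args, save_func=lambda e=epoch_no: saved.append(e),
                      remove_old_func=removed.append, epoch_no=epoch_no, num_train_epochs=10)
print(saved)    # [2, 4, 6, 8]
print(removed)  # [-4, -2, 0, 2] -> only epoch 2 exists on disk, so only it would be deleted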
1688
+ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
1689
+ epoch_no = epoch + 1
1690
+ model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
1691
+
1692
+ if save_stable_diffusion_format:
1693
+ def save_sd():
1694
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1695
+ print(f"saving checkpoint: {ckpt_file}")
1696
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1697
+ src_path, epoch_no, global_step, save_dtype, vae)
1698
+
1699
+ def remove_sd(old_epoch_no):
1700
+ _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
1701
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1702
+ if os.path.exists(old_ckpt_file):
1703
+ print(f"removing old checkpoint: {old_ckpt_file}")
1704
+ os.remove(old_ckpt_file)
1705
+
1706
+ save_func = save_sd
1707
+ remove_old_func = remove_sd
1708
+ else:
1709
+ def save_du():
1710
+ out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
1711
+ print(f"saving model: {out_dir}")
1712
+ os.makedirs(out_dir, exist_ok=True)
1713
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1714
+ src_path, vae=vae, use_safetensors=use_safetensors)
1715
+
1716
+ def remove_du(old_epoch_no):
1717
+ out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
1718
+ if os.path.exists(out_dir_old):
1719
+ print(f"removing old model: {out_dir_old}")
1720
+ shutil.rmtree(out_dir_old)
1721
+
1722
+ save_func = save_du
1723
+ remove_old_func = remove_du
1724
+
1725
+ saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
1726
+ if saving and args.save_state:
1727
+ save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
1728
+
1729
+
1730
+ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
1731
+ print("saving state.")
1732
+ accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
1733
+
1734
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
1735
+ if last_n_epochs is not None:
1736
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
1737
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
1738
+ if os.path.exists(state_dir_old):
1739
+ print(f"removing old state: {state_dir_old}")
1740
+ shutil.rmtree(state_dir_old)
1741
+
1742
+
1743
+ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
1744
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1745
+
1746
+ if save_stable_diffusion_format:
1747
+ os.makedirs(args.output_dir, exist_ok=True)
1748
+
1749
+ ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
1750
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1751
+
1752
+ print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
1753
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1754
+ src_path, epoch, global_step, save_dtype, vae)
1755
+ else:
1756
+ out_dir = os.path.join(args.output_dir, model_name)
1757
+ os.makedirs(out_dir, exist_ok=True)
1758
+
1759
+ print(f"save trained model as Diffusers to {out_dir}")
1760
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1761
+ src_path, vae=vae, use_safetensors=use_safetensors)
1762
+
1763
+
1764
+ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1765
+ print("saving last state.")
1766
+ os.makedirs(args.output_dir, exist_ok=True)
1767
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
+ accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
+
1770
+ # endregion
1771
+
1772
+ # region preprocessing helpers
1773
+
1774
+
1775
+ class ImageLoadingDataset(torch.utils.data.Dataset):
1776
+ def __init__(self, image_paths):
1777
+ self.images = image_paths
1778
+
1779
+ def __len__(self):
1780
+ return len(self.images)
1781
+
1782
+ def __getitem__(self, idx):
1783
+ img_path = self.images[idx]
1784
+
1785
+ try:
1786
+ image = Image.open(img_path).convert("RGB")
1787
+ # convert to tensor temporarily so dataloader will accept it
1788
+ tensor_pil = transforms.functional.pil_to_tensor(image)
1789
+ except Exception as e:
1790
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
1791
+ return None
1792
+
1793
+ return (tensor_pil, img_path)
1794
+
1795
+
1796
+ # endregion
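ImageLoadingDataset above is only used by the preprocessing scripts to stream raw images into a DataLoader before captioning or latent caching; a minimal sketch (the directory is a placeholder, and failed loads return None, which the real scripts filter out):

import torch

paths = glob_images("train_images")   # hypothetical directory of image files
loader = torch.utils.data.DataLoader(ImageLoadingDataset(paths), batch_size=1, num_workers=0)
for tensor_pil, img_path in loader:
    print(img_path[0], tensor_pil.shape)  # uint8 tensor of shape (1, C, H, W)
    break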
fine_tune.py CHANGED
@@ -13,11 +13,7 @@ import diffusers
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
- import library.config_util as config_util
17
- from library.config_util import (
18
- ConfigSanitizer,
19
- BlueprintGenerator,
20
- )
21
 
22
  def collate_fn(examples):
23
  return examples[0]
@@ -34,36 +30,25 @@ def train(args):
34
 
35
  tokenizer = train_util.load_tokenizer(args)
36
 
37
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
38
- if args.dataset_config is not None:
39
- print(f"Load dataset config from {args.dataset_config}")
40
- user_config = config_util.load_user_config(args.dataset_config)
41
- ignored = ["train_data_dir", "in_json"]
42
- if any(getattr(args, attr) is not None for attr in ignored):
43
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
44
- else:
45
- user_config = {
46
- "datasets": [{
47
- "subsets": [{
48
- "image_dir": args.train_data_dir,
49
- "metadata_file": args.in_json,
50
- }]
51
- }]
52
- }
53
-
54
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.debug_dataset:
58
- train_util.debug_dataset(train_dataset_group)
59
  return
60
- if len(train_dataset_group) == 0:
61
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
62
  return
63
 
64
- if cache_latents:
65
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
-
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -124,7 +109,7 @@ def train(args):
124
  vae.requires_grad_(False)
125
  vae.eval()
126
  with torch.no_grad():
127
- train_dataset_group.cache_latents(vae)
128
  vae.to("cpu")
129
  if torch.cuda.is_available():
130
  torch.cuda.empty_cache()
@@ -164,13 +149,33 @@ def train(args):
164
 
165
  # 学習に必要なクラスを準備する
166
  print("prepare optimizer, data loader etc.")
167
- _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  # dataloaderを準備する
170
  # DataLoaderのプロセス数:0はメインプロセスになる
171
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
172
  train_dataloader = torch.utils.data.DataLoader(
173
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
174
 
175
  # 学習ステップ数を計算する
176
  if args.max_train_epochs is not None:
@@ -178,9 +183,8 @@ def train(args):
178
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
179
 
180
  # lr schedulerを用意する
181
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
182
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
183
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
184
 
185
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
186
  if args.full_fp16:
@@ -214,7 +218,7 @@ def train(args):
214
  # 学習する
215
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
216
  print("running training / 学習開始")
217
- print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
218
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
219
  print(f" num epochs / epoch数: {num_train_epochs}")
220
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -233,7 +237,7 @@ def train(args):
233
 
234
  for epoch in range(num_train_epochs):
235
  print(f"epoch {epoch+1}/{num_train_epochs}")
236
- train_dataset_group.set_current_epoch(epoch + 1)
237
 
238
  for m in training_models:
239
  m.train()
@@ -282,11 +286,11 @@ def train(args):
282
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
283
 
284
  accelerator.backward(loss)
285
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
286
  params_to_clip = []
287
  for m in training_models:
288
  params_to_clip.extend(m.parameters())
289
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
@@ -297,16 +301,11 @@ def train(args):
297
  progress_bar.update(1)
298
  global_step += 1
299
 
300
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
301
-
302
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
303
  if args.logging_dir is not None:
304
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
305
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
306
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
307
  accelerator.log(logs, step=global_step)
308
 
309
- # TODO moving averageにする
310
  loss_total += current_loss
311
  avr_loss = loss_total / (step+1)
312
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -316,7 +315,7 @@ def train(args):
316
  break
317
 
318
  if args.logging_dir is not None:
319
- logs = {"loss/epoch": loss_total / len(train_dataloader)}
320
  accelerator.log(logs, step=epoch+1)
321
 
322
  accelerator.wait_for_everyone()
@@ -326,8 +325,6 @@ def train(args):
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
329
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
330
-
331
  is_main_process = accelerator.is_main_process
332
  if is_main_process:
333
  unet = unwrap_model(unet)
@@ -354,8 +351,6 @@ if __name__ == '__main__':
354
  train_util.add_dataset_arguments(parser, False, True, True)
355
  train_util.add_training_arguments(parser, False)
356
  train_util.add_sd_saving_arguments(parser)
357
- train_util.add_optimizer_arguments(parser)
358
- config_util.add_config_arguments(parser)
359
 
360
  parser.add_argument("--diffusers_xformers", action='store_true',
361
  help='use xformers by diffusers / Diffusersでxformersを使用する')
 
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
+
 
 
 
 
17
 
18
  def collate_fn(examples):
19
  return examples[0]
 
30
 
31
  tokenizer = train_util.load_tokenizer(args)
32
 
33
+ train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
34
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
35
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
36
+ args.bucket_reso_steps, args.bucket_no_upscale,
37
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
38
+ args.dataset_repeats, args.debug_dataset)
39
+
40
+ # set caption dropout rates for the training data
41
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
42
+
43
+ train_dataset.make_buckets()
 
 
 
 
 
 
 
 
44
 
45
  if args.debug_dataset:
46
+ train_util.debug_dataset(train_dataset)
47
  return
48
+ if len(train_dataset) == 0:
49
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
50
  return
51
 
 
 
 
52
  # acceleratorを準備する
53
  print("prepare accelerator")
54
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
109
  vae.requires_grad_(False)
110
  vae.eval()
111
  with torch.no_grad():
112
+ train_dataset.cache_latents(vae)
113
  vae.to("cpu")
114
  if torch.cuda.is_available():
115
  torch.cuda.empty_cache()
 
149
 
150
  # 学習に必要なクラスを準備する
151
  print("prepare optimizer, data loader etc.")
152
+
153
+ # use 8-bit Adam if requested
154
+ if args.use_8bit_adam:
155
+ try:
156
+ import bitsandbytes as bnb
157
+ except ImportError:
158
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
159
+ print("use 8-bit Adam optimizer")
160
+ optimizer_class = bnb.optim.AdamW8bit
161
+ elif args.use_lion_optimizer:
162
+ try:
163
+ import lion_pytorch
164
+ except ImportError:
165
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
166
+ print("use Lion optimizer")
167
+ optimizer_class = lion_pytorch.Lion
168
+ else:
169
+ optimizer_class = torch.optim.AdamW
170
+
171
+ # betas and weight decay appear to use the defaults in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
172
+ optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
173
 
174
  # dataloaderを準備する
175
  # DataLoaderのプロセス数:0はメインプロセスになる
176
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
177
  train_dataloader = torch.utils.data.DataLoader(
178
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
179
 
180
  # 学習ステップ数を計算する
181
  if args.max_train_epochs is not None:
 
183
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
+ lr_scheduler = diffusers.optimization.get_scheduler(
187
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
188
 
189
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
190
  if args.full_fp16:
 
218
  # 学習する
219
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
220
  print("running training / 学習開始")
221
+ print(f" num examples / サンプル数: {train_dataset.num_train_images}")
222
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
223
  print(f" num epochs / epoch数: {num_train_epochs}")
224
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
237
 
238
  for epoch in range(num_train_epochs):
239
  print(f"epoch {epoch+1}/{num_train_epochs}")
240
+ train_dataset.set_current_epoch(epoch + 1)
241
 
242
  for m in training_models:
243
  m.train()
 
286
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
287
 
288
  accelerator.backward(loss)
289
+ if accelerator.sync_gradients:
290
  params_to_clip = []
291
  for m in training_models:
292
  params_to_clip.extend(m.parameters())
293
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
294
 
295
  optimizer.step()
296
  lr_scheduler.step()
 
301
  progress_bar.update(1)
302
  global_step += 1
303
 
 
 
304
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
305
  if args.logging_dir is not None:
306
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
307
  accelerator.log(logs, step=global_step)
308
 
 
309
  loss_total += current_loss
310
  avr_loss = loss_total / (step+1)
311
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 
315
  break
316
 
317
  if args.logging_dir is not None:
318
+ logs = {"epoch_loss": loss_total / len(train_dataloader)}
319
  accelerator.log(logs, step=epoch+1)
320
 
321
  accelerator.wait_for_everyone()
 
325
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
326
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
327
 
 
 
328
  is_main_process = accelerator.is_main_process
329
  if is_main_process:
330
  unet = unwrap_model(unet)
 
351
  train_util.add_dataset_arguments(parser, False, True, True)
352
  train_util.add_training_arguments(parser, False)
353
  train_util.add_sd_saving_arguments(parser)
 
 
354
 
355
  parser.add_argument("--diffusers_xformers", action='store_true',
356
  help='use xformers by diffusers / Diffusersでxformersを使用する')
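
A note on the optimizer block added to the training script above: it picks an optimizer class by precedence, bitsandbytes' AdamW8bit when args.use_8bit_adam is set, lion_pytorch's Lion when args.use_lion_optimizer is set, and torch.optim.AdamW otherwise. A minimal, self-contained sketch of that pattern (params_to_optimize here is a stand-in parameter list, not the script's actual trainable parameters):

import torch

def pick_optimizer_class(use_8bit_adam: bool, use_lion_optimizer: bool):
    # precedence: 8-bit AdamW -> Lion -> plain AdamW
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb  # optional dependency
        except ImportError as e:
            raise ImportError("bitsandbytes is not installed") from e
        return bnb.optim.AdamW8bit
    if use_lion_optimizer:
        try:
            import lion_pytorch  # optional dependency
        except ImportError as e:
            raise ImportError("lion_pytorch is not installed") from e
        return lion_pytorch.Lion
    return torch.optim.AdamW

params_to_optimize = [torch.nn.Parameter(torch.zeros(4))]  # stand-in parameters
optimizer_class = pick_optimizer_class(use_8bit_adam=False, use_lion_optimizer=False)
optimizer = optimizer_class(params_to_optimize, lr=1e-5)
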
gen_img_diffusers.py CHANGED
@@ -47,7 +47,7 @@ VGG(
47
  """
48
 
49
  import json
50
- from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
51
  import glob
52
  import importlib
53
  import inspect
@@ -60,6 +60,7 @@ import math
60
  import os
61
  import random
62
  import re
 
63
 
64
  import diffusers
65
  import numpy as np
@@ -80,9 +81,6 @@ from PIL import Image
80
  from PIL.PngImagePlugin import PngInfo
81
 
82
  import library.model_util as model_util
83
- import library.train_util as train_util
84
- import tools.original_control_net as original_control_net
85
- from tools.original_control_net import ControlNetInfo
86
 
87
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
88
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -489,9 +487,6 @@ class PipelineLike():
489
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
490
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
491
 
492
- # ControlNet
493
- self.control_nets: List[ControlNetInfo] = []
494
-
495
  # Textual Inversion
496
  def add_token_replacement(self, target_token_id, rep_token_ids):
497
  self.token_replacements[target_token_id] = rep_token_ids
@@ -505,11 +500,7 @@ class PipelineLike():
505
  new_tokens.append(token)
506
  return new_tokens
507
 
508
- def set_control_nets(self, ctrl_nets):
509
- self.control_nets = ctrl_nets
510
-
511
  # region xformersとか使う部分:独自に書き換えるので関係なし
512
-
513
  def enable_xformers_memory_efficient_attention(self):
514
  r"""
515
  Enable memory efficient attention as implemented in xformers.
@@ -590,8 +581,6 @@ class PipelineLike():
590
  latents: Optional[torch.FloatTensor] = None,
591
  max_embeddings_multiples: Optional[int] = 3,
592
  output_type: Optional[str] = "pil",
593
- vae_batch_size: float = None,
594
- return_latents: bool = False,
595
  # return_dict: bool = True,
596
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -683,9 +672,6 @@ class PipelineLike():
683
  else:
684
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
685
 
686
- vae_batch_size = batch_size if vae_batch_size is None else (
687
- int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
688
-
689
  if strength < 0 or strength > 1:
690
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
691
 
@@ -766,7 +752,7 @@ class PipelineLike():
766
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
767
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
768
 
769
- if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
770
  if isinstance(clip_guide_images, PIL.Image.Image):
771
  clip_guide_images = [clip_guide_images]
772
 
@@ -779,7 +765,7 @@ class PipelineLike():
779
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
780
  if len(image_embeddings_clip) == 1:
781
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
782
- elif self.vgg16_guidance_scale > 0:
783
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
784
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
785
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -788,10 +774,6 @@ class PipelineLike():
788
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
789
  if len(image_embeddings_vgg16) == 1:
790
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
791
- else:
792
- # ControlNetのhintにguide imageを流用する
793
- # 前処理はControlNet側で行う
794
- pass
795
 
796
  # set timesteps
797
  self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -799,6 +781,7 @@ class PipelineLike():
799
  latents_dtype = text_embeddings.dtype
800
  init_latents_orig = None
801
  mask = None
 
802
 
803
  if init_image is None:
804
  # get the initial random noise unless the user supplied it
@@ -830,8 +813,6 @@ class PipelineLike():
830
  if isinstance(init_image[0], PIL.Image.Image):
831
  init_image = [preprocess_image(im) for im in init_image]
832
  init_image = torch.cat(init_image)
833
- if isinstance(init_image, list):
834
- init_image = torch.stack(init_image)
835
 
836
  # mask image to tensor
837
  if mask_image is not None:
@@ -842,24 +823,9 @@ class PipelineLike():
842
 
843
  # encode the init image into latents and scale the latents
844
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
845
- if init_image.size()[2:] == (height // 8, width // 8):
846
- init_latents = init_image
847
- else:
848
- if vae_batch_size >= batch_size:
849
- init_latent_dist = self.vae.encode(init_image).latent_dist
850
- init_latents = init_latent_dist.sample(generator=generator)
851
- else:
852
- if torch.cuda.is_available():
853
- torch.cuda.empty_cache()
854
- init_latents = []
855
- for i in tqdm(range(0, batch_size, vae_batch_size)):
856
- init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
857
- if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
858
- init_latents.append(init_latent_dist.sample(generator=generator))
859
- init_latents = torch.cat(init_latents)
860
-
861
- init_latents = 0.18215 * init_latents
862
-
863
  if len(init_latents) == 1:
864
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
865
  init_latents_orig = init_latents
@@ -898,21 +864,12 @@ class PipelineLike():
898
  extra_step_kwargs["eta"] = eta
899
 
900
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
901
-
902
- if self.control_nets:
903
- guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
904
-
905
  for i, t in enumerate(tqdm(timesteps)):
906
  # expand the latents if we are doing classifier free guidance
907
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
908
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
909
-
910
  # predict the noise residual
911
- if self.control_nets:
912
- noise_pred = original_control_net.call_unet_and_control_net(
913
- i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
914
- else:
915
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
916
 
917
  # perform guidance
918
  if do_classifier_free_guidance:
@@ -954,19 +911,8 @@ class PipelineLike():
954
  if is_cancelled_callback is not None and is_cancelled_callback():
955
  return None
956
 
957
- if return_latents:
958
- return (latents, False)
959
-
960
  latents = 1 / 0.18215 * latents
961
- if vae_batch_size >= batch_size:
962
- image = self.vae.decode(latents).sample
963
- else:
964
- if torch.cuda.is_available():
965
- torch.cuda.empty_cache()
966
- images = []
967
- for i in tqdm(range(0, batch_size, vae_batch_size)):
968
- images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
969
- image = torch.cat(images)
970
 
971
  image = (image / 2 + 0.5).clamp(0, 1)
972
 
@@ -1649,11 +1595,10 @@ def get_unweighted_text_embeddings(
1649
  if pad == eos: # v1
1650
  text_input_chunk[:, -1] = text_input[0, -1]
1651
  else: # v2
1652
- for j in range(len(text_input_chunk)):
1653
- if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
1654
- text_input_chunk[j, -1] = eos
1655
- if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
1656
- text_input_chunk[j, 1] = eos
1657
 
1658
  if clip_skip is None or clip_skip == 1:
1659
  text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -1854,7 +1799,7 @@ def preprocess_mask(mask):
1854
  mask = mask.convert("L")
1855
  w, h = mask.size
1856
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1857
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
1858
  mask = np.array(mask).astype(np.float32) / 255.0
1859
  mask = np.tile(mask, (4, 1, 1))
1860
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -1872,35 +1817,6 @@ def preprocess_mask(mask):
1872
  # return text_encoder
1873
 
1874
 
1875
- class BatchDataBase(NamedTuple):
1876
- # バッチ分割が必要ないデータ
1877
- step: int
1878
- prompt: str
1879
- negative_prompt: str
1880
- seed: int
1881
- init_image: Any
1882
- mask_image: Any
1883
- clip_prompt: str
1884
- guide_image: Any
1885
-
1886
-
1887
- class BatchDataExt(NamedTuple):
1888
- # バッチ分割が必要なデータ
1889
- width: int
1890
- height: int
1891
- steps: int
1892
- scale: float
1893
- negative_scale: float
1894
- strength: float
1895
- network_muls: Tuple[float]
1896
-
1897
-
1898
- class BatchData(NamedTuple):
1899
- return_latents: bool
1900
- base: BatchDataBase
1901
- ext: BatchDataExt
1902
-
1903
-
1904
  def main(args):
1905
  if args.fp16:
1906
  dtype = torch.float16
@@ -1965,7 +1881,10 @@ def main(args):
1965
  # tokenizerを読み込む
1966
  print("loading tokenizer")
1967
  if use_stable_diffusion_format:
1968
- tokenizer = train_util.load_tokenizer(args)
 
 
 
1969
 
1970
  # schedulerを用意する
1971
  sched_init_args = {}
@@ -2076,13 +1995,11 @@ def main(args):
2076
  # networkを組み込む
2077
  if args.network_module:
2078
  networks = []
2079
- network_default_muls = []
2080
  for i, network_module in enumerate(args.network_module):
2081
  print("import network module:", network_module)
2082
  imported_module = importlib.import_module(network_module)
2083
 
2084
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
2085
- network_default_muls.append(network_mul)
2086
 
2087
  net_kwargs = {}
2088
  if args.network_args and i < len(args.network_args):
@@ -2097,7 +2014,7 @@ def main(args):
2097
  network_weight = args.network_weights[i]
2098
  print("load network weights from:", network_weight)
2099
 
2100
- if model_util.is_safetensors(network_weight) and args.network_show_meta:
2101
  from safetensors.torch import safe_open
2102
  with safe_open(network_weight, framework="pt") as f:
2103
  metadata = f.metadata()
@@ -2120,18 +2037,6 @@ def main(args):
2120
  else:
2121
  networks = []
2122
 
2123
- # ControlNetの処理
2124
- control_nets: List[ControlNetInfo] = []
2125
- if args.control_net_models:
2126
- for i, model in enumerate(args.control_net_models):
2127
- prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
2128
- weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
2129
- ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
2130
-
2131
- ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
2132
- prep = original_control_net.load_preprocess(prep_type)
2133
- control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
2134
-
2135
  if args.opt_channels_last:
2136
  print(f"set optimizing: channels last")
2137
  text_encoder.to(memory_format=torch.channels_last)
@@ -2145,14 +2050,9 @@ def main(args):
2145
  if vgg16_model is not None:
2146
  vgg16_model.to(memory_format=torch.channels_last)
2147
 
2148
- for cn in control_nets:
2149
- cn.unet.to(memory_format=torch.channels_last)
2150
- cn.net.to(memory_format=torch.channels_last)
2151
-
2152
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2153
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2154
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
2155
- pipe.set_control_nets(control_nets)
2156
  print("pipeline is ready.")
2157
 
2158
  if args.diffusers_xformers:
@@ -2277,34 +2177,18 @@ def main(args):
2277
  mask_images = l
2278
 
2279
  # 画像サイズにオプション指定があるときはリサイズする
2280
- if args.W is not None and args.H is not None:
2281
- if init_images is not None:
2282
- print(f"resize img2img source images to {args.W}*{args.H}")
2283
- init_images = resize_images(init_images, (args.W, args.H))
2284
  if mask_images is not None:
2285
  print(f"resize img2img mask images to {args.W}*{args.H}")
2286
  mask_images = resize_images(mask_images, (args.W, args.H))
2287
 
2288
- if networks and mask_images:
2289
- # mask を領域情報として流用する、現在は1枚だけ対応
2290
- # TODO 複数のnetwork classの混在時の考慮
2291
- print("use mask as region")
2292
- # import cv2
2293
- # for i in range(3):
2294
- # cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
2295
- # cv2.waitKey()
2296
- # cv2.destroyAllWindows()
2297
- networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
2298
- mask_images = None
2299
-
2300
  prev_image = None # for VGG16 guided
2301
  if args.guide_image_path is not None:
2302
- print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
2303
- guide_images = []
2304
- for p in args.guide_image_path:
2305
- guide_images.extend(load_images(p))
2306
-
2307
- print(f"loaded {len(guide_images)} guide images for guidance")
2308
  if len(guide_images) == 0:
2309
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2310
  guide_images = None
@@ -2335,46 +2219,33 @@ def main(args):
2335
  iter_seed = random.randint(0, 0x7fffffff)
2336
 
2337
  # バッチ処理の関数
2338
- def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
2339
  batch_size = len(batch)
2340
 
2341
  # highres_fixの処理
2342
  if highres_fix and not highres_1st:
2343
- # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
2344
- print("process 1st stage")
2345
  batch_1st = []
2346
- for _, base, ext in batch:
2347
- width_1st = int(ext.width * args.highres_fix_scale + .5)
2348
- height_1st = int(ext.height * args.highres_fix_scale + .5)
2349
  width_1st = width_1st - width_1st % 32
2350
  height_1st = height_1st - height_1st % 32
2351
-
2352
- ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
2353
- ext.negative_scale, ext.strength, ext.network_muls)
2354
- batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
2355
  images_1st = process_batch(batch_1st, True, True)
2356
 
2357
  # 2nd stageのバッチを作成して以下処理する
2358
- print("process 2nd stage")
2359
- if args.highres_fix_latents_upscaling:
2360
- org_dtype = images_1st.dtype
2361
- if images_1st.dtype == torch.bfloat16:
2362
- images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
2363
- images_1st = torch.nn.functional.interpolate(
2364
- images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
2365
- images_1st = images_1st.to(org_dtype)
2366
-
2367
  batch_2nd = []
2368
- for i, (bd, image) in enumerate(zip(batch, images_1st)):
2369
- if not args.highres_fix_latents_upscaling:
2370
- image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
2371
- bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
2372
- batch_2nd.append(bd_2nd)
2373
  batch = batch_2nd
2374
 
2375
- # このバッチの情報を取り出す
2376
- return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
2377
- (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
2378
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2379
 
2380
  prompts = []
@@ -2407,7 +2278,7 @@ def main(args):
2407
  all_images_are_same = True
2408
  all_masks_are_same = True
2409
  all_guide_images_are_same = True
2410
- for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2411
  prompts.append(prompt)
2412
  negative_prompts.append(negative_prompt)
2413
  seeds.append(seed)
@@ -2424,13 +2295,9 @@ def main(args):
2424
  all_masks_are_same = mask_images[-2] is mask_image
2425
 
2426
  if guide_image is not None:
2427
- if type(guide_image) is list:
2428
- guide_images.extend(guide_image)
2429
- all_guide_images_are_same = False
2430
- else:
2431
- guide_images.append(guide_image)
2432
- if i > 0 and all_guide_images_are_same:
2433
- all_guide_images_are_same = guide_images[-2] is guide_image
2434
 
2435
  # make start code
2436
  torch.manual_seed(seed)
@@ -2453,24 +2320,10 @@ def main(args):
2453
  if guide_images is not None and all_guide_images_are_same:
2454
  guide_images = guide_images[0]
2455
 
2456
- # ControlNet使用時はguide imageをリサイズする
2457
- if control_nets:
2458
- # TODO resampleのメソッド
2459
- guide_images = guide_images if type(guide_images) == list else [guide_images]
2460
- guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
2461
- if len(guide_images) == 1:
2462
- guide_images = guide_images[0]
2463
-
2464
  # generate
2465
- if networks:
2466
- for n, m in zip(networks, network_muls if network_muls else network_default_muls):
2467
- n.set_multiplier(m)
2468
-
2469
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2470
- output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
2471
- vae_batch_size=args.vae_batch_size, return_latents=return_latents,
2472
- clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2473
- if highres_1st and not args.highres_fix_save_1st: # return images or latents
2474
  return images
2475
 
2476
  # save image
@@ -2545,7 +2398,6 @@ def main(args):
2545
  strength = 0.8 if args.strength is None else args.strength
2546
  negative_prompt = ""
2547
  clip_prompt = None
2548
- network_muls = None
2549
 
2550
  prompt_args = prompt.strip().split(' --')
2551
  prompt = prompt_args[0]
@@ -2609,15 +2461,6 @@ def main(args):
2609
  clip_prompt = m.group(1)
2610
  print(f"clip prompt: {clip_prompt}")
2611
  continue
2612
-
2613
- m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
2614
- if m: # network multiplies
2615
- network_muls = [float(v) for v in m.group(1).split(",")]
2616
- while len(network_muls) < len(networks):
2617
- network_muls.append(network_muls[-1])
2618
- print(f"network mul: {network_muls}")
2619
- continue
2620
-
2621
  except ValueError as ex:
2622
  print(f"Exception in parsing / 解析エラー: {parg}")
2623
  print(ex)
@@ -2655,12 +2498,7 @@ def main(args):
2655
  mask_image = mask_images[global_step % len(mask_images)]
2656
 
2657
  if guide_images is not None:
2658
- if control_nets: # 複数件の場合あり
2659
- c = len(control_nets)
2660
- p = global_step % (len(guide_images) // c)
2661
- guide_image = guide_images[p * c:p * c + c]
2662
- else:
2663
- guide_image = guide_images[global_step % len(guide_images)]
2664
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2665
  if prev_image is None:
2666
  print("Generate 1st image without guide image.")
@@ -2668,9 +2506,10 @@ def main(args):
2668
  print("Use previous image as guide image.")
2669
  guide_image = prev_image
2670
 
2671
- b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2672
- BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
2673
- if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
 
2674
  process_batch(batch_data, highres_fix)
2675
  batch_data.clear()
2676
 
@@ -2714,8 +2553,6 @@ if __name__ == '__main__':
2714
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2715
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2716
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
2717
- parser.add_argument("--vae_batch_size", type=float, default=None,
2718
- help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
2719
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2720
  parser.add_argument('--sampler', type=str, default='ddim',
2721
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2727,8 +2564,6 @@ if __name__ == '__main__':
2727
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2728
  parser.add_argument("--vae", type=str, default=None,
2729
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
2730
- parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
2731
- help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
2732
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2733
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2734
  parser.add_argument("--seed", type=int, default=None,
@@ -2743,15 +2578,12 @@ if __name__ == '__main__':
2743
  parser.add_argument("--opt_channels_last", action='store_true',
2744
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2745
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2746
- help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
2747
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2748
- help='additional network weights to load / 追加ネットワークの重み')
2749
- parser.add_argument("--network_mul", type=float, default=None, nargs='*',
2750
- help='additional network multiplier / 追加ネットワークの効果の倍率')
2751
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2752
  help='additional arguments for network (key=value) / ネットワークへの追加の引数')
2753
- parser.add_argument("--network_show_meta", action='store_true',
2754
- help='show metadata of network model / ネットワークモデルのメタデータを表示する')
2755
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2756
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2757
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2765,26 +2597,15 @@ if __name__ == '__main__':
2765
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2766
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2767
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2768
- parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
2769
- help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
2770
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2771
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2772
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2773
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2774
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2775
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
2776
- parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
2777
- help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
2778
  parser.add_argument("--negative_scale", type=float, default=None,
2779
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2780
 
2781
- parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
2782
- help='ControlNet models to use / 使用するControlNetのモデル名')
2783
- parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
2784
- help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
2785
- parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
2786
- parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
2787
- help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
2788
-
2789
  args = parser.parse_args()
2790
  main(args)
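
Both the img2img path above (init_latents = 0.18215 * init_latents) and the decode path (latents = 1 / 0.18215 * latents) rely on Stable Diffusion's fixed latent scaling factor. A minimal sketch of that encode/decode round trip with diffusers' AutoencoderKL; the model id and the dummy input are illustrative assumptions, not taken from this script:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image = torch.rand(1, 3, 512, 512) * 2 - 1           # dummy image in [-1, 1]

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()  # -> 1x4x64x64
    latents = 0.18215 * latents                       # scale into the UNet's latent space
    decoded = vae.decode(latents / 0.18215).sample    # undo the scale before decoding
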
 
47
  """
48
 
49
  import json
50
+ from typing import List, Optional, Union
51
  import glob
52
  import importlib
53
  import inspect
 
60
  import os
61
  import random
62
  import re
63
+ from typing import Any, Callable, List, Optional, Union
64
 
65
  import diffusers
66
  import numpy as np
 
81
  from PIL.PngImagePlugin import PngInfo
82
 
83
  import library.model_util as model_util
 
 
 
84
 
85
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
86
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 
487
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
488
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
489
 
 
 
 
490
  # Textual Inversion
491
  def add_token_replacement(self, target_token_id, rep_token_ids):
492
  self.token_replacements[target_token_id] = rep_token_ids
 
500
  new_tokens.append(token)
501
  return new_tokens
502
 
 
 
 
503
  # region xformersとか使う部分:独自に書き換えるので関係なし
 
504
  def enable_xformers_memory_efficient_attention(self):
505
  r"""
506
  Enable memory efficient attention as implemented in xformers.
 
581
  latents: Optional[torch.FloatTensor] = None,
582
  max_embeddings_multiples: Optional[int] = 3,
583
  output_type: Optional[str] = "pil",
 
 
584
  # return_dict: bool = True,
585
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
586
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
 
672
  else:
673
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
674
 
 
 
 
675
  if strength < 0 or strength > 1:
676
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
677
 
 
752
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
753
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
754
 
755
+ if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
756
  if isinstance(clip_guide_images, PIL.Image.Image):
757
  clip_guide_images = [clip_guide_images]
758
 
 
765
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
766
  if len(image_embeddings_clip) == 1:
767
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
768
+ else:
769
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
770
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
771
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
 
774
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
775
  if len(image_embeddings_vgg16) == 1:
776
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
 
 
 
 
777
 
778
  # set timesteps
779
  self.scheduler.set_timesteps(num_inference_steps, self.device)
 
781
  latents_dtype = text_embeddings.dtype
782
  init_latents_orig = None
783
  mask = None
784
+ noise = None
785
 
786
  if init_image is None:
787
  # get the initial random noise unless the user supplied it
 
813
  if isinstance(init_image[0], PIL.Image.Image):
814
  init_image = [preprocess_image(im) for im in init_image]
815
  init_image = torch.cat(init_image)
 
 
816
 
817
  # mask image to tensor
818
  if mask_image is not None:
 
823
 
824
  # encode the init image into latents and scale the latents
825
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
826
+ init_latent_dist = self.vae.encode(init_image).latent_dist
827
+ init_latents = init_latent_dist.sample(generator=generator)
828
+ init_latents = 0.18215 * init_latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829
  if len(init_latents) == 1:
830
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
831
  init_latents_orig = init_latents
 
864
  extra_step_kwargs["eta"] = eta
865
 
866
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
 
 
 
 
867
  for i, t in enumerate(tqdm(timesteps)):
868
  # expand the latents if we are doing classifier free guidance
869
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
870
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
871
  # predict the noise residual
872
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
 
 
873
 
874
  # perform guidance
875
  if do_classifier_free_guidance:
 
911
  if is_cancelled_callback is not None and is_cancelled_callback():
912
  return None
913
 
 
 
 
914
  latents = 1 / 0.18215 * latents
915
+ image = self.vae.decode(latents).sample
 
 
 
 
 
 
 
 
916
 
917
  image = (image / 2 + 0.5).clamp(0, 1)
918
 
 
1595
  if pad == eos: # v1
1596
  text_input_chunk[:, -1] = text_input[0, -1]
1597
  else: # v2
1598
+ if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
1599
+ text_input_chunk[:, -1] = eos
1600
+ if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
1601
+ text_input_chunk[:, 1] = eos
 
1602
 
1603
  if clip_skip is None or clip_skip == 1:
1604
  text_embedding = pipe.text_encoder(text_input_chunk)[0]
 
1799
  mask = mask.convert("L")
1800
  w, h = mask.size
1801
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1802
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
1803
  mask = np.array(mask).astype(np.float32) / 255.0
1804
  mask = np.tile(mask, (4, 1, 1))
1805
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
 
1817
  # return text_encoder
1818
 
1819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1820
  def main(args):
1821
  if args.fp16:
1822
  dtype = torch.float16
 
1881
  # tokenizerを読み込む
1882
  print("loading tokenizer")
1883
  if use_stable_diffusion_format:
1884
+ if args.v2:
1885
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1886
+ else:
1887
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
1888
 
1889
  # schedulerを用意する
1890
  sched_init_args = {}
 
1995
  # networkを組み込む
1996
  if args.network_module:
1997
  networks = []
 
1998
  for i, network_module in enumerate(args.network_module):
1999
  print("import network module:", network_module)
2000
  imported_module = importlib.import_module(network_module)
2001
 
2002
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
 
2003
 
2004
  net_kwargs = {}
2005
  if args.network_args and i < len(args.network_args):
 
2014
  network_weight = args.network_weights[i]
2015
  print("load network weights from:", network_weight)
2016
 
2017
+ if model_util.is_safetensors(network_weight):
2018
  from safetensors.torch import safe_open
2019
  with safe_open(network_weight, framework="pt") as f:
2020
  metadata = f.metadata()
 
2037
  else:
2038
  networks = []
2039
 
 
 
 
 
 
 
 
 
 
 
 
 
2040
  if args.opt_channels_last:
2041
  print(f"set optimizing: channels last")
2042
  text_encoder.to(memory_format=torch.channels_last)
 
2050
  if vgg16_model is not None:
2051
  vgg16_model.to(memory_format=torch.channels_last)
2052
 
 
 
 
 
2053
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2054
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2055
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
 
2056
  print("pipeline is ready.")
2057
 
2058
  if args.diffusers_xformers:
 
2177
  mask_images = l
2178
 
2179
  # 画像サイズにオプション指定があるときはリサイズする
2180
+ if init_images is not None and args.W is not None and args.H is not None:
2181
+ print(f"resize img2img source images to {args.W}*{args.H}")
2182
+ init_images = resize_images(init_images, (args.W, args.H))
 
2183
  if mask_images is not None:
2184
  print(f"resize img2img mask images to {args.W}*{args.H}")
2185
  mask_images = resize_images(mask_images, (args.W, args.H))
2186
 
 
 
 
 
 
 
 
 
 
 
 
 
2187
  prev_image = None # for VGG16 guided
2188
  if args.guide_image_path is not None:
2189
+ print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
2190
+ guide_images = load_images(args.guide_image_path)
2191
+ print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
 
 
 
2192
  if len(guide_images) == 0:
2193
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2194
  guide_images = None
 
2219
  iter_seed = random.randint(0, 0x7fffffff)
2220
 
2221
  # バッチ処理の関数
2222
+ def process_batch(batch, highres_fix, highres_1st=False):
2223
  batch_size = len(batch)
2224
 
2225
  # highres_fixの処理
2226
  if highres_fix and not highres_1st:
2227
+ # 1st stageのバッチを作成して呼び出す
2228
+ print("process 1st stage1")
2229
  batch_1st = []
2230
+ for params1, (width, height, steps, scale, negative_scale, strength) in batch:
2231
+ width_1st = int(width * args.highres_fix_scale + .5)
2232
+ height_1st = int(height * args.highres_fix_scale + .5)
2233
  width_1st = width_1st - width_1st % 32
2234
  height_1st = height_1st - height_1st % 32
2235
+ batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
 
 
 
2236
  images_1st = process_batch(batch_1st, True, True)
2237
 
2238
  # 2nd stageのバッチを作成して以下処理する
2239
+ print("process 2nd stage1")
 
 
 
 
 
 
 
 
2240
  batch_2nd = []
2241
+ for i, (b1, image) in enumerate(zip(batch, images_1st)):
2242
+ image = image.resize((width, height), resample=PIL.Image.LANCZOS)
2243
+ (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
2244
+ batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
 
2245
  batch = batch_2nd
2246
 
2247
+ (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
2248
+ height, steps, scale, negative_scale, strength) = batch[0]
 
2249
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2250
 
2251
  prompts = []
 
2278
  all_images_are_same = True
2279
  all_masks_are_same = True
2280
  all_guide_images_are_same = True
2281
+ for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2282
  prompts.append(prompt)
2283
  negative_prompts.append(negative_prompt)
2284
  seeds.append(seed)
 
2295
  all_masks_are_same = mask_images[-2] is mask_image
2296
 
2297
  if guide_image is not None:
2298
+ guide_images.append(guide_image)
2299
+ if i > 0 and all_guide_images_are_same:
2300
+ all_guide_images_are_same = guide_images[-2] is guide_image
 
 
 
 
2301
 
2302
  # make start code
2303
  torch.manual_seed(seed)
 
2320
  if guide_images is not None and all_guide_images_are_same:
2321
  guide_images = guide_images[0]
2322
 
 
 
 
 
 
 
 
 
2323
  # generate
 
 
 
 
2324
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2325
+ output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2326
+ if highres_1st and not args.highres_fix_save_1st:
 
 
2327
  return images
2328
 
2329
  # save image
 
2398
  strength = 0.8 if args.strength is None else args.strength
2399
  negative_prompt = ""
2400
  clip_prompt = None
 
2401
 
2402
  prompt_args = prompt.strip().split(' --')
2403
  prompt = prompt_args[0]
 
2461
  clip_prompt = m.group(1)
2462
  print(f"clip prompt: {clip_prompt}")
2463
  continue
 
 
 
 
 
 
 
 
 
2464
  except ValueError as ex:
2465
  print(f"Exception in parsing / 解析エラー: {parg}")
2466
  print(ex)
 
2498
  mask_image = mask_images[global_step % len(mask_images)]
2499
 
2500
  if guide_images is not None:
2501
+ guide_image = guide_images[global_step % len(guide_images)]
 
 
 
 
 
2502
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2503
  if prev_image is None:
2504
  print("Generate 1st image without guide image.")
 
2506
  print("Use previous image as guide image.")
2507
  guide_image = prev_image
2508
 
2509
+ # TODO named tupleか何かにする
2510
+ b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2511
+ (width, height, steps, scale, negative_scale, strength))
2512
+ if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
2513
  process_batch(batch_data, highres_fix)
2514
  batch_data.clear()
2515
 
 
2553
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2554
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2555
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
 
 
2556
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2557
  parser.add_argument('--sampler', type=str, default='ddim',
2558
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
 
2564
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2565
  parser.add_argument("--vae", type=str, default=None,
2566
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
 
 
2567
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2568
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2569
  parser.add_argument("--seed", type=int, default=None,
 
2578
  parser.add_argument("--opt_channels_last", action='store_true',
2579
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2580
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2581
+ help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
2582
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2583
+ help='Hypernetwork weights to load / Hypernetworkの重み')
2584
+ parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
 
2585
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2586
  help='additional arguments for network (key=value) / ネットワークへの追加の引数')
 
 
2587
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2588
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2589
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
 
2597
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2598
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2599
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2600
+ parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
 
2601
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2602
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2603
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2604
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2605
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2606
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
 
 
2607
  parser.add_argument("--negative_scale", type=float, default=None,
2608
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2609
 
 
 
 
 
 
 
 
 
2610
  args = parser.parse_args()
2611
  main(args)
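
The highres fix branch of process_batch above is a two-stage pass: the 1st stage renders at the target resolution multiplied by args.highres_fix_scale, snapped down to a multiple of 32, and its output is resized with LANCZOS and reused as the img2img source for the 2nd stage. A minimal sketch of that flow; generate_fn is a hypothetical stand-in for the actual pipeline call:

from typing import Callable, Optional
from PIL import Image

def highres_fix(width: int, height: int, scale: float, steps_1st: int, steps_2nd: int,
                generate_fn: Callable[[int, int, int, Optional[Image.Image]], Image.Image]) -> Image.Image:
    # 1st stage: smaller resolution, snapped down to a multiple of 32
    width_1st = int(width * scale + .5)
    height_1st = int(height * scale + .5)
    width_1st -= width_1st % 32
    height_1st -= height_1st % 32
    image_1st = generate_fn(width_1st, height_1st, steps_1st, None)        # txt2img
    # 2nd stage: upscale the 1st-stage result and run it through img2img at full size
    image_1st = image_1st.resize((width, height), resample=Image.LANCZOS)
    return generate_fn(width, height, steps_2nd, image_1st)
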
library.egg-info/PKG-INFO ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: library
3
+ Version: 0.0.0
4
+ License-File: LICENSE.md
library.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE.md
2
+ README.md
3
+ setup.py
4
+ library/__init__.py
5
+ library/model_util.py
6
+ library/train_util.py
7
+ library.egg-info/PKG-INFO
8
+ library.egg-info/SOURCES.txt
9
+ library.egg-info/dependency_links.txt
10
+ library.egg-info/top_level.txt
library.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
library.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ library
library/model_util.py CHANGED
@@ -4,7 +4,7 @@
4
  import math
5
  import os
6
  import torch
7
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
8
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
  from safetensors.torch import load_file, save_file
10
 
@@ -916,11 +916,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
916
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
  else:
918
  converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
-
920
- logging.set_verbosity_error() # don't show annoying warning
921
  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
922
- logging.set_verbosity_warning()
923
-
924
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
925
  print("loading text encoder:", info)
926
 
 
4
  import math
5
  import os
6
  import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
  from safetensors.torch import load_file, save_file
10
 
 
916
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
  else:
918
  converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
 
 
919
  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
 
920
  info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
  print("loading text encoder:", info)
922
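
For reference, the warning suppression this commit drops from model_util.py is just a temporary change of transformers' log level around from_pretrained; a standalone sketch of the same pattern:

from transformers import CLIPTextModel, logging

logging.set_verbosity_error()    # hide the "weights are newly initialized" warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
logging.set_verbosity_warning()  # restore the default verbosity
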
 
library/train_util.py CHANGED
@@ -1,21 +1,12 @@
1
  # common functions for training
2
 
3
  import argparse
4
- import importlib
5
  import json
6
- import re
7
  import shutil
8
  import time
9
- from typing import (
10
- Dict,
11
- List,
12
- NamedTuple,
13
- Optional,
14
- Sequence,
15
- Tuple,
16
- Union,
17
- )
18
  from accelerate import Accelerator
 
19
  import glob
20
  import math
21
  import os
@@ -26,16 +17,10 @@ from io import BytesIO
26
 
27
  from tqdm import tqdm
28
  import torch
29
- from torch.optim import Optimizer
30
  from torchvision import transforms
31
  from transformers import CLIPTokenizer
32
- import transformers
33
  import diffusers
34
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
- from diffusers import (StableDiffusionPipeline, DDPMScheduler,
36
- EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
37
- LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
38
- KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
39
  import albumentations as albu
40
  import numpy as np
41
  from PIL import Image
@@ -210,95 +195,23 @@ class BucketBatchIndex(NamedTuple):
210
  batch_index: int
211
 
212
 
213
- class AugHelper:
214
- def __init__(self):
215
- # prepare all possible augmentators
216
- color_aug_method = albu.OneOf([
217
- albu.HueSaturationValue(8, 0, 0, p=.5),
218
- albu.RandomGamma((95, 105), p=.5),
219
- ], p=.33)
220
- flip_aug_method = albu.HorizontalFlip(p=0.5)
221
-
222
- # key: (use_color_aug, use_flip_aug)
223
- self.augmentors = {
224
- (True, True): albu.Compose([
225
- color_aug_method,
226
- flip_aug_method,
227
- ], p=1.),
228
- (True, False): albu.Compose([
229
- color_aug_method,
230
- ], p=1.),
231
- (False, True): albu.Compose([
232
- flip_aug_method,
233
- ], p=1.),
234
- (False, False): None
235
- }
236
-
237
- def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
238
- return self.augmentors[(use_color_aug, use_flip_aug)]
239
-
240
-
241
- class BaseSubset:
242
- def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
243
- self.image_dir = image_dir
244
- self.num_repeats = num_repeats
245
- self.shuffle_caption = shuffle_caption
246
- self.keep_tokens = keep_tokens
247
- self.color_aug = color_aug
248
- self.flip_aug = flip_aug
249
- self.face_crop_aug_range = face_crop_aug_range
250
- self.random_crop = random_crop
251
- self.caption_dropout_rate = caption_dropout_rate
252
- self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
253
- self.caption_tag_dropout_rate = caption_tag_dropout_rate
254
-
255
- self.img_count = 0
256
-
257
-
258
- class DreamBoothSubset(BaseSubset):
259
- def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
260
- assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
261
-
262
- super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
263
- face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
264
-
265
- self.is_reg = is_reg
266
- self.class_tokens = class_tokens
267
- self.caption_extension = caption_extension
268
-
269
- def __eq__(self, other) -> bool:
270
- if not isinstance(other, DreamBoothSubset):
271
- return NotImplemented
272
- return self.image_dir == other.image_dir
273
-
274
-
275
- class FineTuningSubset(BaseSubset):
276
- def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
277
- assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
278
-
279
- super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
280
- face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
281
-
282
- self.metadata_file = metadata_file
283
-
284
- def __eq__(self, other) -> bool:
285
- if not isinstance(other, FineTuningSubset):
286
- return NotImplemented
287
- return self.metadata_file == other.metadata_file
288
-
289
-
290
  class BaseDataset(torch.utils.data.Dataset):
291
- def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
292
  super().__init__()
293
- self.tokenizer = tokenizer
294
  self.max_token_length = max_token_length
 
 
295
  # width/height is used when enable_bucket==False
296
  self.width, self.height = (None, None) if resolution is None else resolution
 
 
 
297
  self.debug_dataset = debug_dataset
298
-
299
- self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
300
-
301
  self.token_padding_disabled = False
 
 
302
  self.tag_frequency = {}
303
 
304
  self.enable_bucket = False
@@ -312,28 +225,49 @@ class BaseDataset(torch.utils.data.Dataset):
312
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
313
 
314
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
 
 
315
 
316
  # augmentation
317
- self.aug_helper = AugHelper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
320
 
321
  self.image_data: Dict[str, ImageInfo] = {}
322
- self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
323
 
324
  self.replacements = {}
325
 
326
  def set_current_epoch(self, epoch):
327
  self.current_epoch = epoch
328
- self.shuffle_buckets()
 
 
 
 
 
329
 
330
  def set_tag_frequency(self, dir_name, captions):
331
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
332
  self.tag_frequency[dir_name] = frequency_for_dir
333
  for caption in captions:
334
  for tag in caption.split(","):
335
- tag = tag.strip()
336
- if tag:
337
  tag = tag.lower()
338
  frequency = frequency_for_dir.get(tag, 0)
339
  frequency_for_dir[tag] = frequency + 1
@@ -344,36 +278,42 @@ class BaseDataset(torch.utils.data.Dataset):
344
  def add_replacement(self, str_from, str_to):
345
  self.replacements[str_from] = str_to
346
 
347
- def process_caption(self, subset: BaseSubset, caption):
348
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
349
- is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
350
- is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
351
 
352
  if is_drop_out:
353
  caption = ""
354
  else:
355
- if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
356
  def dropout_tags(tokens):
357
- if subset.caption_tag_dropout_rate <= 0:
358
  return tokens
359
  l = []
360
  for token in tokens:
361
- if random.random() >= subset.caption_tag_dropout_rate:
362
  l.append(token)
363
  return l
364
 
365
- fixed_tokens = []
366
- flex_tokens = [t.strip() for t in caption.strip().split(",")]
367
- if subset.keep_tokens > 0:
368
- fixed_tokens = flex_tokens[:subset.keep_tokens]
369
- flex_tokens = flex_tokens[subset.keep_tokens:]
 
 
 
 
 
370
 
371
- if subset.shuffle_caption:
372
- random.shuffle(flex_tokens)
373
 
374
- flex_tokens = dropout_tags(flex_tokens)
375
 
376
- caption = ", ".join(fixed_tokens + flex_tokens)
 
377
 
378
  # textual inversion対応
379
  for str_from, str_to in self.replacements.items():
@@ -427,9 +367,8 @@ class BaseDataset(torch.utils.data.Dataset):
427
  input_ids = torch.stack(iids_list) # 3,77
428
  return input_ids
429
 
430
- def register_image(self, info: ImageInfo, subset: BaseSubset):
431
  self.image_data[info.image_key] = info
432
- self.image_to_subset[info.image_key] = subset
433
 
434
  def make_buckets(self):
435
  '''
@@ -528,7 +467,7 @@ class BaseDataset(torch.utils.data.Dataset):
528
  img = np.array(image, np.uint8)
529
  return img
530
 
531
- def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
532
  image_height, image_width = image.shape[0:2]
533
 
534
  if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -538,27 +477,22 @@ class BaseDataset(torch.utils.data.Dataset):
538
  image_height, image_width = image.shape[0:2]
539
  if image_width > reso[0]:
540
  trim_size = image_width - reso[0]
541
- p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
542
  # print("w", trim_size, p)
543
  image = image[:, p:p + reso[0]]
544
  if image_height > reso[1]:
545
  trim_size = image_height - reso[1]
546
- p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
547
  # print("h", trim_size, p)
548
  image = image[p:p + reso[1]]
549
 
550
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
551
  return image
552
 
553
- def is_latent_cacheable(self):
554
- return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
555
-
556
  def cache_latents(self, vae):
557
  # TODO ここを高速化したい
558
  print("caching latents.")
559
  for info in tqdm(self.image_data.values()):
560
- subset = self.image_to_subset[info.image_key]
561
-
562
  if info.latents_npz is not None:
563
  info.latents = self.load_latents_from_npz(info, False)
564
  info.latents = torch.FloatTensor(info.latents)
@@ -568,13 +502,13 @@ class BaseDataset(torch.utils.data.Dataset):
568
  continue
569
 
570
  image = self.load_image(info.absolute_path)
571
- image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
572
 
573
  img_tensor = self.image_transforms(image)
574
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
575
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
576
 
577
- if subset.flip_aug:
578
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
579
  img_tensor = self.image_transforms(image)
580
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -584,11 +518,11 @@ class BaseDataset(torch.utils.data.Dataset):
584
  image = Image.open(image_path)
585
  return image.size
586
 
587
- def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
588
  img = self.load_image(image_path)
589
 
590
  face_cx = face_cy = face_w = face_h = 0
591
- if subset.face_crop_aug_range is not None:
592
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
593
  if len(tokens) >= 5:
594
  face_cx = int(tokens[-4])
@@ -599,7 +533,7 @@ class BaseDataset(torch.utils.data.Dataset):
599
  return img, face_cx, face_cy, face_w, face_h
600
 
601
  # いい感じに切り出す
602
- def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
603
  height, width = image.shape[0:2]
604
  if height == self.height and width == self.width:
605
  return image
@@ -607,8 +541,8 @@ class BaseDataset(torch.utils.data.Dataset):
607
  # 画像サイズはsizeより大きいのでリサイズする
608
  face_size = max(face_w, face_h)
609
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
610
- min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
611
- max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
612
  if min_scale >= max_scale: # range指定がmin==max
613
  scale = min_scale
614
  else:
@@ -626,13 +560,13 @@ class BaseDataset(torch.utils.data.Dataset):
626
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
627
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
628
 
629
- if subset.random_crop:
630
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
631
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
632
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
633
  else:
634
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
635
- if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
636
  if face_size > self.size // 10 and face_size >= 40:
637
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
638
 
@@ -655,6 +589,9 @@ class BaseDataset(torch.utils.data.Dataset):
655
  return self._length
656
 
657
  def __getitem__(self, index):
 
 
 
658
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
659
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
660
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -667,29 +604,28 @@ class BaseDataset(torch.utils.data.Dataset):
667
 
668
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
669
  image_info = self.image_data[image_key]
670
- subset = self.image_to_subset[image_key]
671
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
672
 
673
  # image/latentsを処理する
674
  if image_info.latents is not None:
675
- latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
676
  image = None
677
  elif image_info.latents_npz is not None:
678
- latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
679
  latents = torch.FloatTensor(latents)
680
  image = None
681
  else:
682
  # 画像を読み込み、必要ならcropする
683
- img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
684
  im_h, im_w = img.shape[0:2]
685
 
686
  if self.enable_bucket:
687
- img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
688
  else:
689
  if face_cx > 0: # 顔位置情報あり
690
- img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
691
  elif im_h > self.height or im_w > self.width:
692
- assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
693
  if im_h > self.height:
694
  p = random.randint(0, im_h - self.height)
695
  img = img[p:p + self.height]
@@ -701,9 +637,8 @@ class BaseDataset(torch.utils.data.Dataset):
701
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
702
 
703
  # augmentation
704
- aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
705
- if aug is not None:
706
- img = aug(image=img)['image']
707
 
708
  latents = None
709
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -711,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
711
  images.append(image)
712
  latents_list.append(latents)
713
 
714
- caption = self.process_caption(subset, image_info.caption)
715
  captions.append(caption)
716
  if not self.token_padding_disabled: # this option might be omitted in future
717
  input_ids_list.append(self.get_input_ids(caption))
@@ -742,8 +677,9 @@ class BaseDataset(torch.utils.data.Dataset):
742
 
743
 
744
  class DreamBoothDataset(BaseDataset):
745
- def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
746
- super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
747
 
748
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
749
 
@@ -766,7 +702,7 @@ class DreamBoothDataset(BaseDataset):
766
  self.bucket_reso_steps = None # この情報は使われない
767
  self.bucket_no_upscale = False
768
 
769
- def read_caption(img_path, caption_extension):
770
  # captionの候補ファイル名を作る
771
  base_name = os.path.splitext(img_path)[0]
772
  base_name_face_det = base_name
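read_caption looks for a text file that sits next to the image and shares its base name. A simplified sketch of that lookup (the real function also tries the face-detection-suffixed base name and several extensions):

import os

def read_caption_sketch(img_path: str, caption_extension: str = ".caption"):
    cap_path = os.path.splitext(img_path)[0] + caption_extension
    if not os.path.isfile(cap_path):
        return None  # caller falls back to class tokens / folder caption
    with open(cap_path, "rt", encoding="utf-8") as f:
        lines = f.readlines()
    return lines[0].strip() if lines else ""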
@@ -789,181 +725,153 @@ class DreamBoothDataset(BaseDataset):
789
  break
790
  return caption
791
 
792
- def load_dreambooth_dir(subset: DreamBoothSubset):
793
- if not os.path.isdir(subset.image_dir):
794
- print(f"not directory: {subset.image_dir}")
795
- return [], []
796
 
797
- img_paths = glob_images(subset.image_dir, "*")
798
- print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
 
 
 
 
 
 
 
799
 
800
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
801
  captions = []
802
  for img_path in img_paths:
803
- cap_for_img = read_caption(img_path, subset.caption_extension)
804
- if cap_for_img is None and subset.class_tokens is None:
805
- print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
806
- captions.append("")
807
- else:
808
- captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
809
 
810
- self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
811
 
812
- return img_paths, captions
813
 
814
- print("prepare images.")
 
815
  num_train_images = 0
816
- num_reg_images = 0
817
- reg_infos: List[ImageInfo] = []
818
- for subset in subsets:
819
- if subset.num_repeats < 1:
820
- print(
821
- f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
822
- continue
823
-
824
- if subset in self.subsets:
825
- print(
826
- f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
827
- continue
828
-
829
- img_paths, captions = load_dreambooth_dir(subset)
830
- if len(img_paths) < 1:
831
- print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
832
- continue
833
-
834
- if subset.is_reg:
835
- num_reg_images += subset.num_repeats * len(img_paths)
836
- else:
837
- num_train_images += subset.num_repeats * len(img_paths)
838
 
839
  for img_path, caption in zip(img_paths, captions):
840
- info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
841
- if subset.is_reg:
842
- reg_infos.append(info)
843
- else:
844
- self.register_image(info, subset)
845
 
846
- subset.img_count = len(img_paths)
847
- self.subsets.append(subset)
848
 
849
  print(f"{num_train_images} train images with repeating.")
850
  self.num_train_images = num_train_images
851
 
852
- print(f"{num_reg_images} reg images.")
853
- if num_train_images < num_reg_images:
854
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
855
-
856
- if num_reg_images == 0:
857
- print("no regularization images / 正則化画像が見つかりませんでした")
858
- else:
859
- # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
860
- n = 0
861
- first_loop = True
862
- while n < num_train_images:
863
- for info in reg_infos:
864
- if first_loop:
865
- self.register_image(info, subset)
866
- n += info.num_repeats
867
- else:
868
- info.num_repeats += 1
869
- n += 1
870
- if n >= num_train_images:
871
- break
872
- first_loop = False
873
-
874
- self.num_reg_images = num_reg_images
875
-
876
-
877
- class FineTuningDataset(BaseDataset):
878
- def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
879
- super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
880
 
881
- self.batch_size = batch_size
 
 
 
882
 
883
- self.num_train_images = 0
884
- self.num_reg_images = 0
 
885
 
886
- for subset in subsets:
887
- if subset.num_repeats < 1:
888
- print(
889
- f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
890
- continue
891
 
892
- if subset in self.subsets:
893
- print(
894
- f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
895
- continue
896
 
897
- # メタデータを読み込む
898
- if os.path.exists(subset.metadata_file):
899
- print(f"loading existing metadata: {subset.metadata_file}")
900
- with open(subset.metadata_file, "rt", encoding='utf-8') as f:
901
- metadata = json.load(f)
902
  else:
903
- raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
904
 
905
- if len(metadata) < 1:
906
- print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
907
- continue
908
-
909
- tags_list = []
910
- for image_key, img_md in metadata.items():
911
- # path情報を作る
912
- if os.path.exists(image_key):
913
- abs_path = image_key
914
- else:
915
- npz_path = os.path.join(subset.image_dir, image_key + ".npz")
916
- if os.path.exists(npz_path):
917
- abs_path = npz_path
918
- else:
919
- # わりといい加減だがいい方法が思いつかん
920
- abs_path = glob_images(subset.image_dir, image_key)
921
- assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
922
- abs_path = abs_path[0]
923
-
924
- caption = img_md.get('caption')
925
- tags = img_md.get('tags')
926
- if caption is None:
927
- caption = tags
928
- elif tags is not None and len(tags) > 0:
929
- caption = caption + ', ' + tags
930
- tags_list.append(tags)
931
-
932
- if caption is None:
933
- caption = ""
934
 
935
- image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
936
- image_info.image_size = img_md.get('train_resolution')
937
 
938
- if not subset.color_aug and not subset.random_crop:
939
- # if npz exists, use them
940
- image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
 
 
 
 
 
 
 
 
 
941
 
942
- self.register_image(image_info, subset)
 
 
943
 
944
- self.num_train_images += len(metadata) * subset.num_repeats
945
 
946
- # TODO do not record tag freq when no tag
947
- self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
948
- subset.img_count = len(metadata)
949
- self.subsets.append(subset)
950
 
951
  # check existence of all npz files
952
- use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
953
  if use_npz_latents:
954
- flip_aug_in_subset = False
955
  npz_any = False
956
  npz_all = True
957
-
958
  for image_info in self.image_data.values():
959
- subset = self.image_to_subset[image_info.image_key]
960
-
961
  has_npz = image_info.latents_npz is not None
962
  npz_any = npz_any or has_npz
963
 
964
- if subset.flip_aug:
965
  has_npz = has_npz and image_info.latents_npz_flipped is not None
966
- flip_aug_in_subset = True
967
  npz_all = npz_all and has_npz
968
 
969
  if npz_any and not npz_all:
@@ -975,7 +883,7 @@ class FineTuningDataset(BaseDataset):
975
  elif not npz_all:
976
  use_npz_latents = False
977
  print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
978
- if flip_aug_in_subset:
979
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
980
  # else:
981
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -1021,7 +929,7 @@ class FineTuningDataset(BaseDataset):
1021
  for image_info in self.image_data.values():
1022
  image_info.latents_npz = image_info.latents_npz_flipped = None
1023
 
1024
- def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
1025
  base_name = os.path.splitext(image_key)[0]
1026
  npz_file_norm = base_name + '.npz'
1027
 
@@ -1033,8 +941,8 @@ class FineTuningDataset(BaseDataset):
1033
  return npz_file_norm, npz_file_flip
1034
 
1035
  # image_key is relative path
1036
- npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
1037
- npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
1038
 
1039
  if not os.path.exists(npz_file_norm):
1040
  npz_file_norm = None
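The npz lookup above follows a simple convention: cached latents for an image key live next to it as <base>.npz, and the horizontally flipped latents as <base>_flip.npz, either of which may be missing. A small sketch of that mapping (helper name is illustrative):

import os

def npz_paths_sketch(image_dir: str, image_key: str):
    base = os.path.splitext(image_key)[0]
    npz_norm = os.path.join(image_dir, base + ".npz")
    npz_flip = os.path.join(image_dir, base + "_flip.npz")
    return (npz_norm if os.path.exists(npz_norm) else None,
            npz_flip if os.path.exists(npz_flip) else None)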
@@ -1045,60 +953,13 @@ class FineTuningDataset(BaseDataset):
1045
  return npz_file_norm, npz_file_flip
1046
 
1047
 
1048
- # behave as Dataset mock
1049
- class DatasetGroup(torch.utils.data.ConcatDataset):
1050
- def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
1051
- self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
1052
-
1053
- super().__init__(datasets)
1054
-
1055
- self.image_data = {}
1056
- self.num_train_images = 0
1057
- self.num_reg_images = 0
1058
-
1059
- # simply concat together
1060
- # TODO: handling image_data key duplication among dataset
1061
- # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
1062
- for dataset in datasets:
1063
- self.image_data.update(dataset.image_data)
1064
- self.num_train_images += dataset.num_train_images
1065
- self.num_reg_images += dataset.num_reg_images
1066
-
1067
- def add_replacement(self, str_from, str_to):
1068
- for dataset in self.datasets:
1069
- dataset.add_replacement(str_from, str_to)
1070
-
1071
- # def make_buckets(self):
1072
- # for dataset in self.datasets:
1073
- # dataset.make_buckets()
1074
-
1075
- def cache_latents(self, vae):
1076
- for i, dataset in enumerate(self.datasets):
1077
- print(f"[Dataset {i}]")
1078
- dataset.cache_latents(vae)
1079
-
1080
- def is_latent_cacheable(self) -> bool:
1081
- return all([dataset.is_latent_cacheable() for dataset in self.datasets])
1082
-
1083
- def set_current_epoch(self, epoch):
1084
- for dataset in self.datasets:
1085
- dataset.set_current_epoch(epoch)
1086
-
1087
- def disable_token_padding(self):
1088
- for dataset in self.datasets:
1089
- dataset.disable_token_padding()
1090
-
1091
-
1092
  def debug_dataset(train_dataset, show_input_ids=False):
1093
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
1094
  print("Escape for exit. / Escキーで中断、終了します")
1095
 
1096
  train_dataset.set_current_epoch(1)
1097
  k = 0
1098
- indices = list(range(len(train_dataset)))
1099
- random.shuffle(indices)
1100
- for i, idx in enumerate(indices):
1101
- example = train_dataset[idx]
1102
  if example['latents'] is not None:
1103
  print(f"sample has latents from npz file: {example['latents'].size()}")
1104
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
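The DatasetGroup removed in this hunk is essentially a thin wrapper over torch's ConcatDataset that sums a few counters and fans method calls out to its members. A pared-down sketch of that pattern (not the full class):

import torch.utils.data

class DatasetGroupSketch(torch.utils.data.ConcatDataset):
    """Concatenate several datasets and forward bookkeeping calls to each member."""

    def __init__(self, datasets):
        super().__init__(datasets)
        self.num_train_images = sum(getattr(d, "num_train_images", 0) for d in datasets)
        self.num_reg_images = sum(getattr(d, "num_reg_images", 0) for d in datasets)

    def set_current_epoch(self, epoch):
        # ConcatDataset keeps the member list in self.datasets
        for d in self.datasets:
            d.set_current_epoch(epoch)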
@@ -1503,35 +1364,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1503
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1504
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1505
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1506
- parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
1507
- help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
1508
-
1509
-
1510
- def add_optimizer_arguments(parser: argparse.ArgumentParser):
1511
- parser.add_argument("--optimizer_type", type=str, default="",
1512
- help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
1513
-
1514
- # backward compatibility
1515
- parser.add_argument("--use_8bit_adam", action="store_true",
1516
- help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1517
- parser.add_argument("--use_lion_optimizer", action="store_true",
1518
- help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1519
-
1520
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1521
- parser.add_argument("--max_grad_norm", default=1.0, type=float,
1522
- help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
1523
-
1524
- parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
1525
- help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
1526
-
1527
- parser.add_argument("--lr_scheduler", type=str, default="constant",
1528
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
1529
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
1530
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1531
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
1532
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
1533
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
1534
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
1535
 
1536
 
1537
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1555,6 +1387,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1555
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1556
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1557
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
 
 
 
 
1558
  parser.add_argument("--mem_eff_attn", action="store_true",
1559
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1560
  parser.add_argument("--xformers", action="store_true",
@@ -1562,6 +1398,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1562
  parser.add_argument("--vae", type=str, default=None,
1563
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1564
 
 
1565
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1566
  parser.add_argument("--max_train_epochs", type=int, default=None,
1567
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1582,23 +1419,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1582
  parser.add_argument("--logging_dir", type=str, default=None,
1583
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1584
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
 
 
 
 
1585
  parser.add_argument("--noise_offset", type=float, default=None,
1586
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1587
  parser.add_argument("--lowram", action="store_true",
1588
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1589
 
1590
- parser.add_argument("--sample_every_n_steps", type=int, default=None,
1591
- help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
1592
- parser.add_argument("--sample_every_n_epochs", type=int, default=None,
1593
- help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
1594
- parser.add_argument("--sample_prompts", type=str, default=None,
1595
- help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
1596
- parser.add_argument('--sample_sampler', type=str, default='ddim',
1597
- choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
1598
- 'dpmsolver++', 'dpmsingle',
1599
- 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
1600
- help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
1601
-
1602
  if support_dreambooth:
1603
  # DreamBooth training
1604
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1620,8 +1449,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1620
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1621
  parser.add_argument("--caption_extention", type=str, default=None,
1622
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1623
- parser.add_argument("--keep_tokens", type=int, default=0,
1624
- help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
1625
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1626
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1627
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1646,11 +1475,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1646
  if support_caption_dropout:
1647
  # Textual Inversion はcaptionのdropoutをsupportしない
1648
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1649
- parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
1650
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1651
- parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
1652
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1653
- parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
1654
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1655
 
1656
  if support_dreambooth:
@@ -1675,256 +1504,16 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1675
  # region utils
1676
 
1677
 
1678
- def get_optimizer(args, trainable_params):
1679
- # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
1680
-
1681
- optimizer_type = args.optimizer_type
1682
- if args.use_8bit_adam:
1683
- assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
1684
- assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
1685
- optimizer_type = "AdamW8bit"
1686
-
1687
- elif args.use_lion_optimizer:
1688
- assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
1689
- optimizer_type = "Lion"
1690
-
1691
- if optimizer_type is None or optimizer_type == "":
1692
- optimizer_type = "AdamW"
1693
- optimizer_type = optimizer_type.lower()
1694
-
1695
- # 引数を分解する:boolとfloat、tupleのみ対応
1696
- optimizer_kwargs = {}
1697
- if args.optimizer_args is not None and len(args.optimizer_args) > 0:
1698
- for arg in args.optimizer_args:
1699
- key, value = arg.split('=')
1700
-
1701
- value = value.split(",")
1702
- for i in range(len(value)):
1703
- if value[i].lower() == "true" or value[i].lower() == "false":
1704
- value[i] = (value[i].lower() == "true")
1705
- else:
1706
- value[i] = float(value[i])
1707
- if len(value) == 1:
1708
- value = value[0]
1709
- else:
1710
- value = tuple(value)
1711
-
1712
- optimizer_kwargs[key] = value
1713
- # print("optkwargs:", optimizer_kwargs)
1714
-
1715
- lr = args.learning_rate
1716
-
1717
- if optimizer_type == "AdamW8bit".lower():
1718
- try:
1719
- import bitsandbytes as bnb
1720
- except ImportError:
1721
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1722
- print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
1723
- optimizer_class = bnb.optim.AdamW8bit
1724
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1725
-
1726
- elif optimizer_type == "SGDNesterov8bit".lower():
1727
- try:
1728
- import bitsandbytes as bnb
1729
- except ImportError:
1730
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1731
- print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
1732
- if "momentum" not in optimizer_kwargs:
1733
- print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1734
- optimizer_kwargs["momentum"] = 0.9
1735
-
1736
- optimizer_class = bnb.optim.SGD8bit
1737
- optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1738
-
1739
- elif optimizer_type == "Lion".lower():
1740
- try:
1741
- import lion_pytorch
1742
- except ImportError:
1743
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
1744
- print(f"use Lion optimizer | {optimizer_kwargs}")
1745
- optimizer_class = lion_pytorch.Lion
1746
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1747
-
1748
- elif optimizer_type == "SGDNesterov".lower():
1749
- print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
1750
- if "momentum" not in optimizer_kwargs:
1751
- print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1752
- optimizer_kwargs["momentum"] = 0.9
1753
-
1754
- optimizer_class = torch.optim.SGD
1755
- optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1756
-
1757
- elif optimizer_type == "DAdaptation".lower():
1758
- try:
1759
- import dadaptation
1760
- except ImportError:
1761
- raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
1762
- print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
1763
-
1764
- actual_lr = lr
1765
- lr_count = 1
1766
- if type(trainable_params) == list and type(trainable_params[0]) == dict:
1767
- lrs = set()
1768
- actual_lr = trainable_params[0].get("lr", actual_lr)
1769
- for group in trainable_params:
1770
- lrs.add(group.get("lr", actual_lr))
1771
- lr_count = len(lrs)
1772
-
1773
- if actual_lr <= 0.1:
1774
- print(
1775
- f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
1776
- print('recommend option: lr=1.0 / 推奨は1.0です')
1777
- if lr_count > 1:
1778
- print(
1779
- f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
1780
-
1781
- optimizer_class = dadaptation.DAdaptAdam
1782
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1783
-
1784
- elif optimizer_type == "Adafactor".lower():
1785
- # 引数を確認して適宜補正する
1786
- if "relative_step" not in optimizer_kwargs:
1787
- optimizer_kwargs["relative_step"] = True # default
1788
- if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
1789
- print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
1790
- optimizer_kwargs["relative_step"] = True
1791
- print(f"use Adafactor optimizer | {optimizer_kwargs}")
1792
-
1793
- if optimizer_kwargs["relative_step"]:
1794
- print(f"relative_step is true / relative_stepがtrueです")
1795
- if lr != 0.0:
1796
- print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
1797
- args.learning_rate = None
1798
-
1799
- # trainable_paramsがgroupだった時の処理:lrを削除する
1800
- if type(trainable_params) == list and type(trainable_params[0]) == dict:
1801
- has_group_lr = False
1802
- for group in trainable_params:
1803
- p = group.pop("lr", None)
1804
- has_group_lr = has_group_lr or (p is not None)
1805
-
1806
- if has_group_lr:
1807
- # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
1808
- print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
1809
- args.unet_lr = None
1810
- args.text_encoder_lr = None
1811
-
1812
- if args.lr_scheduler != "adafactor":
1813
- print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
1814
- args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
1815
-
1816
- lr = None
1817
- else:
1818
- if args.max_grad_norm != 0.0:
1819
- print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
1820
- if args.lr_scheduler != "constant_with_warmup":
1821
- print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
1822
- if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
1823
- print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
1824
-
1825
- optimizer_class = transformers.optimization.Adafactor
1826
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1827
-
1828
- elif optimizer_type == "AdamW".lower():
1829
- print(f"use AdamW optimizer | {optimizer_kwargs}")
1830
- optimizer_class = torch.optim.AdamW
1831
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1832
-
1833
- else:
1834
- # 任意のoptimizerを使う
1835
- optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
1836
- print(f"use {optimizer_type} | {optimizer_kwargs}")
1837
- if "." not in optimizer_type:
1838
- optimizer_module = torch.optim
1839
- else:
1840
- values = optimizer_type.split(".")
1841
- optimizer_module = importlib.import_module(".".join(values[:-1]))
1842
- optimizer_type = values[-1]
1843
-
1844
- optimizer_class = getattr(optimizer_module, optimizer_type)
1845
- optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1846
-
1847
- optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
1848
- optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
1849
-
1850
- return optimizer_name, optimizer_args, optimizer
1851
-
1852
-
1853
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
1854
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
1855
- # Which is a newer release of diffusers than currently packaged with sd-scripts
1856
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
1857
-
1858
-
1859
- def get_scheduler_fix(
1860
- name: Union[str, SchedulerType],
1861
- optimizer: Optimizer,
1862
- num_warmup_steps: Optional[int] = None,
1863
- num_training_steps: Optional[int] = None,
1864
- num_cycles: int = 1,
1865
- power: float = 1.0,
1866
- ):
1867
- """
1868
- Unified API to get any scheduler from its name.
1869
- Args:
1870
- name (`str` or `SchedulerType`):
1871
- The name of the scheduler to use.
1872
- optimizer (`torch.optim.Optimizer`):
1873
- The optimizer that will be used during training.
1874
- num_warmup_steps (`int`, *optional*):
1875
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
1876
- optional), the function will raise an error if it's unset and the scheduler type requires it.
1877
- num_training_steps (`int``, *optional*):
1878
- The number of training steps to do. This is not required by all schedulers (hence the argument being
1879
- optional), the function will raise an error if it's unset and the scheduler type requires it.
1880
- num_cycles (`int`, *optional*):
1881
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
1882
- power (`float`, *optional*, defaults to 1.0):
1883
- Power factor. See `POLYNOMIAL` scheduler
1884
- last_epoch (`int`, *optional*, defaults to -1):
1885
- The index of the last epoch when resuming training.
1886
- """
1887
- if name.startswith("adafactor"):
1888
- assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
1889
- initial_lr = float(name.split(':')[1])
1890
- # print("adafactor scheduler init lr", initial_lr)
1891
- return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
1892
-
1893
- name = SchedulerType(name)
1894
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
1895
- if name == SchedulerType.CONSTANT:
1896
- return schedule_func(optimizer)
1897
-
1898
- # All other schedulers require `num_warmup_steps`
1899
- if num_warmup_steps is None:
1900
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
1901
-
1902
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
1903
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
1904
-
1905
- # All other schedulers require `num_training_steps`
1906
- if num_training_steps is None:
1907
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1908
-
1909
- if name == SchedulerType.COSINE_WITH_RESTARTS:
1910
- return schedule_func(
1911
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
1912
- )
1913
-
1914
- if name == SchedulerType.POLYNOMIAL:
1915
- return schedule_func(
1916
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
1917
- )
1918
-
1919
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
1920
-
1921
-
1922
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1923
  # backward compatibility
1924
  if args.caption_extention is not None:
1925
  args.caption_extension = args.caption_extention
1926
  args.caption_extention = None
1927
 
 
 
 
 
1928
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1929
  if args.resolution is not None:
1930
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
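The removed get_optimizer turns --optimizer_args entries such as "weight_decay=0.01 betas=0.9,0.999" into keyword arguments, supporting only bools, floats and tuples. A standalone sketch of just that parsing step:

def parse_optimizer_args_sketch(optimizer_args):
    # optimizer_args: e.g. ["weight_decay=0.01", "betas=0.9,0.999", "amsgrad=true"]
    kwargs = {}
    for arg in optimizer_args or []:
        key, value = arg.split("=")
        parts = []
        for v in value.split(","):
            parts.append(v.lower() == "true" if v.lower() in ("true", "false") else float(v))
        kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
    return kwargs

print(parse_optimizer_args_sketch(["weight_decay=0.01", "betas=0.9,0.999"]))
# -> {'weight_decay': 0.01, 'betas': (0.9, 0.999)}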
@@ -1947,28 +1536,12 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1947
 
1948
  def load_tokenizer(args: argparse.Namespace):
1949
  print("prepare tokenizer")
1950
- original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
1951
-
1952
- tokenizer: CLIPTokenizer = None
1953
- if args.tokenizer_cache_dir:
1954
- local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
1955
- if os.path.exists(local_tokenizer_path):
1956
- print(f"load tokenizer from cache: {local_tokenizer_path}")
1957
- tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
1958
-
1959
- if tokenizer is None:
1960
- if args.v2:
1961
- tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
1962
- else:
1963
- tokenizer = CLIPTokenizer.from_pretrained(original_path)
1964
-
1965
- if hasattr(args, "max_token_length") and args.max_token_length is not None:
1966
  print(f"update token length: {args.max_token_length}")
1967
-
1968
- if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
1969
- print(f"save Tokenizer to cache: {local_tokenizer_path}")
1970
- tokenizer.save_pretrained(local_tokenizer_path)
1971
-
1972
  return tokenizer
1973
 
1974
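The tokenizer-cache logic removed above boils down to: load the CLIP tokenizer from the Hub, or from a local cache directory if one exists, and save it back for offline training. A hedged sketch using the standard transformers API; the default model id and function name are illustrative:

import os
from typing import Optional
from transformers import CLIPTokenizer

def load_tokenizer_sketch(name_or_path: str = "openai/clip-vit-large-patch14",
                          cache_dir: Optional[str] = None) -> CLIPTokenizer:
    if cache_dir:
        local_path = os.path.join(cache_dir, name_or_path.replace("/", "_"))
        if os.path.exists(local_path):
            return CLIPTokenizer.from_pretrained(local_path)  # reuse the offline copy
    tokenizer = CLIPTokenizer.from_pretrained(name_or_path)
    if cache_dir:
        tokenizer.save_pretrained(os.path.join(cache_dir, name_or_path.replace("/", "_")))
    return tokenizer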
 
@@ -2019,19 +1592,13 @@ def prepare_dtype(args: argparse.Namespace):
2019
 
2020
 
2021
  def load_target_model(args: argparse.Namespace, weight_dtype):
2022
- name_or_path = args.pretrained_model_name_or_path
2023
- name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
2024
- load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
2025
  if load_stable_diffusion_format:
2026
  print("load StableDiffusion checkpoint")
2027
- text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
2028
  else:
2029
  print("load Diffusers pretrained models")
2030
- try:
2031
- pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
2032
- except EnvironmentError as ex:
2033
- print(
2034
- f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
2035
  text_encoder = pipe.text_encoder
2036
  vae = pipe.vae
2037
  unet = pipe.unet
@@ -2200,197 +1767,6 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
2200
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
2201
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
2202
 
2203
-
2204
- # scheduler:
2205
- SCHEDULER_LINEAR_START = 0.00085
2206
- SCHEDULER_LINEAR_END = 0.0120
2207
- SCHEDULER_TIMESTEPS = 1000
2208
- SCHEDLER_SCHEDULE = 'scaled_linear'
2209
-
2210
-
2211
- def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
2212
- """
2213
- 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
2214
- clip skipは対応した
2215
- """
2216
- if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
2217
- return
2218
- if args.sample_every_n_epochs is not None:
2219
- # sample_every_n_steps は無視する
2220
- if epoch is None or epoch % args.sample_every_n_epochs != 0:
2221
- return
2222
- else:
2223
- if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
2224
- return
2225
-
2226
- print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
2227
- if not os.path.isfile(args.sample_prompts):
2228
- print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
2229
- return
2230
-
2231
- org_vae_device = vae.device # CPUにいるはず
2232
- vae.to(device)
2233
-
2234
- # clip skip 対応のための wrapper を作る
2235
- if args.clip_skip is None:
2236
- text_encoder_or_wrapper = text_encoder
2237
- else:
2238
- class Wrapper():
2239
- def __init__(self, tenc) -> None:
2240
- self.tenc = tenc
2241
- self.config = {}
2242
- super().__init__()
2243
-
2244
- def __call__(self, input_ids, attention_mask):
2245
- enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
2246
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
2247
- encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
2248
- pooled_output = enc_out['pooler_output']
2249
- return encoder_hidden_states, pooled_output # 1st output is only used
2250
-
2251
- text_encoder_or_wrapper = Wrapper(text_encoder)
2252
-
2253
- # read prompts
2254
- with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
2255
- prompts = f.readlines()
2256
-
2257
- # schedulerを用意する
2258
- sched_init_args = {}
2259
- if args.sample_sampler == "ddim":
2260
- scheduler_cls = DDIMScheduler
2261
- elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
2262
- scheduler_cls = DDPMScheduler
2263
- elif args.sample_sampler == "pndm":
2264
- scheduler_cls = PNDMScheduler
2265
- elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
2266
- scheduler_cls = LMSDiscreteScheduler
2267
- elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
2268
- scheduler_cls = EulerDiscreteScheduler
2269
- elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
2270
- scheduler_cls = EulerAncestralDiscreteScheduler
2271
- elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
2272
- scheduler_cls = DPMSolverMultistepScheduler
2273
- sched_init_args['algorithm_type'] = args.sample_sampler
2274
- elif args.sample_sampler == "dpmsingle":
2275
- scheduler_cls = DPMSolverSinglestepScheduler
2276
- elif args.sample_sampler == "heun":
2277
- scheduler_cls = HeunDiscreteScheduler
2278
- elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
2279
- scheduler_cls = KDPM2DiscreteScheduler
2280
- elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
2281
- scheduler_cls = KDPM2AncestralDiscreteScheduler
2282
- else:
2283
- scheduler_cls = DDIMScheduler
2284
-
2285
- if args.v_parameterization:
2286
- sched_init_args['prediction_type'] = 'v_prediction'
2287
-
2288
- scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
2289
- beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
2290
- beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
2291
-
2292
- # clip_sample=Trueにする
2293
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
2294
- # print("set clip_sample to True")
2295
- scheduler.config.clip_sample = True
2296
-
2297
- pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
2298
- scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
2299
- pipeline.to(device)
2300
-
2301
- save_dir = args.output_dir + "/sample"
2302
- os.makedirs(save_dir, exist_ok=True)
2303
-
2304
- rng_state = torch.get_rng_state()
2305
- cuda_rng_state = torch.cuda.get_rng_state()
2306
-
2307
- with torch.no_grad():
2308
- with accelerator.autocast():
2309
- for i, prompt in enumerate(prompts):
2310
- if not accelerator.is_main_process:
2311
- continue
2312
- prompt = prompt.strip()
2313
- if len(prompt) == 0 or prompt[0] == '#':
2314
- continue
2315
-
2316
- # subset of gen_img_diffusers
2317
- prompt_args = prompt.split(' --')
2318
- prompt = prompt_args[0]
2319
- negative_prompt = None
2320
- sample_steps = 30
2321
- width = height = 512
2322
- scale = 7.5
2323
- seed = None
2324
- for parg in prompt_args:
2325
- try:
2326
- m = re.match(r'w (\d+)', parg, re.IGNORECASE)
2327
- if m:
2328
- width = int(m.group(1))
2329
- continue
2330
-
2331
- m = re.match(r'h (\d+)', parg, re.IGNORECASE)
2332
- if m:
2333
- height = int(m.group(1))
2334
- continue
2335
-
2336
- m = re.match(r'd (\d+)', parg, re.IGNORECASE)
2337
- if m:
2338
- seed = int(m.group(1))
2339
- continue
2340
-
2341
- m = re.match(r's (\d+)', parg, re.IGNORECASE)
2342
- if m: # steps
2343
- sample_steps = max(1, min(1000, int(m.group(1))))
2344
- continue
2345
-
2346
- m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
2347
- if m: # scale
2348
- scale = float(m.group(1))
2349
- continue
2350
-
2351
- m = re.match(r'n (.+)', parg, re.IGNORECASE)
2352
- if m: # negative prompt
2353
- negative_prompt = m.group(1)
2354
- continue
2355
-
2356
- except ValueError as ex:
2357
- print(f"Exception in parsing / 解析エラー: {parg}")
2358
- print(ex)
2359
-
2360
- if seed is not None:
2361
- torch.manual_seed(seed)
2362
- torch.cuda.manual_seed(seed)
2363
-
2364
- if prompt_replacement is not None:
2365
- prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
2366
- if negative_prompt is not None:
2367
- negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
2368
-
2369
- height = max(64, height - height % 8) # round to divisible by 8
2370
- width = max(64, width - width % 8) # round to divisible by 8
2371
- print(f"prompt: {prompt}")
2372
- print(f"negative_prompt: {negative_prompt}")
2373
- print(f"height: {height}")
2374
- print(f"width: {width}")
2375
- print(f"sample_steps: {sample_steps}")
2376
- print(f"scale: {scale}")
2377
- image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
2378
-
2379
- ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
2380
- num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
2381
- seed_suffix = "" if seed is None else f"_{seed}"
2382
- img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
2383
-
2384
- image.save(os.path.join(save_dir, img_filename))
2385
-
2386
- # clear pipeline and cache to reduce vram usage
2387
- del pipeline
2388
- torch.cuda.empty_cache()
2389
-
2390
- torch.set_rng_state(rng_state)
2391
- torch.cuda.set_rng_state(cuda_rng_state)
2392
- vae.to(org_vae_device)
2393
-
2394
  # endregion
2395
 
2396
  # region 前処理用
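The removed sample_images parses per-prompt options of the form "prompt --w 512 --h 448 --d 42 --s 30 --l 7.5 --n negative". A standalone sketch of just that option parsing; the returned keys are illustrative:

import re

def parse_sample_prompt_sketch(line: str) -> dict:
    parts = line.strip().split(" --")
    opts = {"prompt": parts[0], "negative_prompt": None,
            "width": 512, "height": 512, "steps": 30, "scale": 7.5, "seed": None}
    flags = {"w": ("width", int), "h": ("height", int), "d": ("seed", int),
             "s": ("steps", int), "l": ("scale", float), "n": ("negative_prompt", str)}
    for parg in parts[1:]:
        m = re.match(r"(\w) (.+)", parg)
        if not m:
            continue
        key, conv = flags.get(m.group(1).lower(), (None, None))
        if key:
            opts[key] = conv(m.group(2))
    return opts

print(parse_sample_prompt_sketch("a cat --w 640 --h 448 --d 42 --n lowres"))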
 
1
  # common functions for training
2
 
3
  import argparse
 
4
  import json
 
5
  import shutil
6
  import time
7
+ from typing import Dict, List, NamedTuple, Tuple
 
 
 
 
 
 
 
 
8
  from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
  import glob
11
  import math
12
  import os
 
17
 
18
  from tqdm import tqdm
19
  import torch
 
20
  from torchvision import transforms
21
  from transformers import CLIPTokenizer
 
22
  import diffusers
23
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
 
 
 
 
24
  import albumentations as albu
25
  import numpy as np
26
  from PIL import Image
 
195
  batch_index: int
196
 
197
198
  class BaseDataset(torch.utils.data.Dataset):
199
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
  super().__init__()
201
+ self.tokenizer: CLIPTokenizer = tokenizer
202
  self.max_token_length = max_token_length
203
+ self.shuffle_caption = shuffle_caption
204
+ self.shuffle_keep_tokens = shuffle_keep_tokens
205
  # width/height is used when enable_bucket==False
206
  self.width, self.height = (None, None) if resolution is None else resolution
207
+ self.face_crop_aug_range = face_crop_aug_range
208
+ self.flip_aug = flip_aug
209
+ self.color_aug = color_aug
210
  self.debug_dataset = debug_dataset
211
+ self.random_crop = random_crop
 
 
212
  self.token_padding_disabled = False
213
+ self.dataset_dirs_info = {}
214
+ self.reg_dataset_dirs_info = {}
215
  self.tag_frequency = {}
216
 
217
  self.enable_bucket = False
 
225
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
 
227
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
228
+ self.dropout_rate: float = 0
229
+ self.dropout_every_n_epochs: int = None
230
+ self.tag_dropout_rate: float = 0
231
 
232
  # augmentation
233
+ flip_p = 0.5 if flip_aug else 0.0
234
+ if color_aug:
235
+ # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
236
+ self.aug = albu.Compose([
237
+ albu.OneOf([
238
+ albu.HueSaturationValue(8, 0, 0, p=.5),
239
+ albu.RandomGamma((95, 105), p=.5),
240
+ ], p=.33),
241
+ albu.HorizontalFlip(p=flip_p)
242
+ ], p=1.)
243
+ elif flip_aug:
244
+ self.aug = albu.Compose([
245
+ albu.HorizontalFlip(p=flip_p)
246
+ ], p=1.)
247
+ else:
248
+ self.aug = None
249
 
250
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
 
252
  self.image_data: Dict[str, ImageInfo] = {}
 
253
 
254
  self.replacements = {}
255
 
256
  def set_current_epoch(self, epoch):
257
  self.current_epoch = epoch
258
+
259
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
+ # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
261
+ self.dropout_rate = dropout_rate
262
+ self.dropout_every_n_epochs = dropout_every_n_epochs
263
+ self.tag_dropout_rate = tag_dropout_rate
264
 
265
  def set_tag_frequency(self, dir_name, captions):
266
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
  self.tag_frequency[dir_name] = frequency_for_dir
268
  for caption in captions:
269
  for tag in caption.split(","):
270
+ if tag and not tag.isspace():
 
271
  tag = tag.lower()
272
  frequency = frequency_for_dir.get(tag, 0)
273
  frequency_for_dir[tag] = frequency + 1
 
278
  def add_replacement(self, str_from, str_to):
279
  self.replacements[str_from] = str_to
280
 
281
+ def process_caption(self, caption):
282
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
283
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
 
286
  if is_drop_out:
287
  caption = ""
288
  else:
289
+ if self.shuffle_caption or self.tag_dropout_rate > 0:
290
  def dropout_tags(tokens):
291
+ if self.tag_dropout_rate <= 0:
292
  return tokens
293
  l = []
294
  for token in tokens:
295
+ if random.random() >= self.tag_dropout_rate:
296
  l.append(token)
297
  return l
298
 
299
+ tokens = [t.strip() for t in caption.strip().split(",")]
300
+ if self.shuffle_keep_tokens is None:
301
+ if self.shuffle_caption:
302
+ random.shuffle(tokens)
303
+
304
+ tokens = dropout_tags(tokens)
305
+ else:
306
+ if len(tokens) > self.shuffle_keep_tokens:
307
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
308
+ tokens = tokens[self.shuffle_keep_tokens:]
309
 
310
+ if self.shuffle_caption:
311
+ random.shuffle(tokens)
312
 
313
+ tokens = dropout_tags(tokens)
314
 
315
+ tokens = keep_tokens + tokens
316
+ caption = ", ".join(tokens)
317
 
318
  # textual inversion対応
319
  for str_from, str_to in self.replacements.items():
 
367
  input_ids = torch.stack(iids_list) # 3,77
368
  return input_ids
369
 
370
+ def register_image(self, info: ImageInfo):
371
  self.image_data[info.image_key] = info
 
372
 
373
  def make_buckets(self):
374
  '''
 
467
  img = np.array(image, np.uint8)
468
  return img
469
 
470
+ def trim_and_resize_if_required(self, image, reso, resized_size):
471
  image_height, image_width = image.shape[0:2]
472
 
473
  if image_width != resized_size[0] or image_height != resized_size[1]:
 
477
  image_height, image_width = image.shape[0:2]
478
  if image_width > reso[0]:
479
  trim_size = image_width - reso[0]
480
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
  # print("w", trim_size, p)
482
  image = image[:, p:p + reso[0]]
483
  if image_height > reso[1]:
484
  trim_size = image_height - reso[1]
485
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
  # print("h", trim_size, p)
487
  image = image[p:p + reso[1]]
488
 
489
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
  return image
491
 
 
 
 
492
  def cache_latents(self, vae):
493
  # TODO ここを高速化したい
494
  print("caching latents.")
495
  for info in tqdm(self.image_data.values()):
 
 
496
  if info.latents_npz is not None:
497
  info.latents = self.load_latents_from_npz(info, False)
498
  info.latents = torch.FloatTensor(info.latents)
 
502
  continue
503
 
504
  image = self.load_image(info.absolute_path)
505
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
 
507
  img_tensor = self.image_transforms(image)
508
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
 
511
+ if self.flip_aug:
512
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
  img_tensor = self.image_transforms(image)
514
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
 
518
  image = Image.open(image_path)
519
  return image.size
520
 
521
+ def load_image_with_face_info(self, image_path: str):
522
  img = self.load_image(image_path)
523
 
524
  face_cx = face_cy = face_w = face_h = 0
525
+ if self.face_crop_aug_range is not None:
526
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
  if len(tokens) >= 5:
528
  face_cx = int(tokens[-4])
 
533
  return img, face_cx, face_cy, face_w, face_h
534
 
535
  # いい感じに切り出す
536
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
  height, width = image.shape[0:2]
538
  if height == self.height and width == self.width:
539
  return image
 
541
  # 画像サイズはsizeより大きいのでリサイズする
542
  face_size = max(face_w, face_h)
543
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
544
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
545
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
546
  if min_scale >= max_scale: # range指定がmin==max
547
  scale = min_scale
548
  else:
 
560
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
562
 
563
+ if self.random_crop:
564
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
565
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
566
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
567
  else:
568
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
569
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
  if face_size > self.size // 10 and face_size >= 40:
571
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
 
 
589
  return self._length
590
 
591
  def __getitem__(self, index):
592
+ if index == 0:
593
+ self.shuffle_buckets()
594
+
595
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
 
604
 
605
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
  image_info = self.image_data[image_key]
 
607
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
 
609
  # image/latentsを処理する
610
  if image_info.latents is not None:
611
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
  image = None
613
  elif image_info.latents_npz is not None:
614
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
  latents = torch.FloatTensor(latents)
616
  image = None
617
  else:
618
  # 画像を読み込み、必要ならcropする
619
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
  im_h, im_w = img.shape[0:2]
621
 
622
  if self.enable_bucket:
623
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
  else:
625
  if face_cx > 0: # 顔位置情報あり
626
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
  elif im_h > self.height or im_w > self.width:
628
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
  if im_h > self.height:
630
  p = random.randint(0, im_h - self.height)
631
  img = img[p:p + self.height]
 
637
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
 
639
  # augmentation
640
+ if self.aug is not None:
641
+ img = self.aug(image=img)['image']
 
642
 
643
  latents = None
644
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
 
646
  images.append(image)
647
  latents_list.append(latents)
648
 
649
+ caption = self.process_caption(image_info.caption)
650
  captions.append(caption)
651
  if not self.token_padding_disabled: # this option might be omitted in future
652
  input_ids_list.append(self.get_input_ids(caption))
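# Condensed sketch of the branch above (`latent_source` is a hypothetical helper,
# not part of BaseDataset): each sample uses latents cached in memory if available,
# else an .npz cache on disk, else decodes and augments the image; with flip_aug
# the flipped variant is picked with 50% probability.
def latent_source(image_info) -> str:
    if image_info.latents is not None:
        return "memory"   # latents kept in RAM by cache_latents
    if image_info.latents_npz is not None:
        return "npz"      # latents cached on disk, optionally with a *_flip.npz variant
    return "image"        # load the file, crop/resize, then run self.aug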
 
677
 
678
 
679
  class DreamBoothDataset(BaseDataset):
680
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
 
684
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
 
 
702
  self.bucket_reso_steps = None # この情報は使われない
703
  self.bucket_no_upscale = False
704
 
705
+ def read_caption(img_path):
706
  # captionの候補ファイル名を作る
707
  base_name = os.path.splitext(img_path)[0]
708
  base_name_face_det = base_name
 
725
  break
726
  return caption
727
 
728
+ def load_dreambooth_dir(dir):
729
+ if not os.path.isdir(dir):
730
+ # print(f"ignore file: {dir}")
731
+ return 0, [], []
732
 
733
+ tokens = os.path.basename(dir).split('_')
734
+ try:
735
+ n_repeats = int(tokens[0])
736
+ except ValueError as e:
737
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
+ return 0, [], []
739
+
740
+ caption_by_folder = '_'.join(tokens[1:])
741
+ img_paths = glob_images(dir, "*")
742
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
 
744
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
745
  captions = []
746
  for img_path in img_paths:
747
+ cap_for_img = read_caption(img_path)
748
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
 
 
 
 
749
 
750
+ self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
751
 
752
+ return n_repeats, img_paths, captions
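# Worked example of the folder convention parsed by load_dreambooth_dir above
# (paths and helper name are hypothetical): a subfolder named "20_sks girl" yields
# n_repeats=20 and the fallback caption "sks girl" for images without a caption file.
import os

def parse_dreambooth_dir_name(dir_path: str):
    tokens = os.path.basename(dir_path).split('_')
    try:
        n_repeats = int(tokens[0])
    except ValueError:
        return None  # directories without a leading repeat count are ignored
    return n_repeats, '_'.join(tokens[1:])

# parse_dreambooth_dir_name(r"train/20_sks girl") -> (20, 'sks girl')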
753
 
754
+ print("prepare train images.")
755
+ train_dirs = os.listdir(train_data_dir)
756
  num_train_images = 0
757
+ for dir in train_dirs:
758
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
+ num_train_images += n_repeats * len(img_paths)
 

760
 
761
  for img_path, caption in zip(img_paths, captions):
762
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
+ self.register_image(info)
 
 
 
764
 
765
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
766
 
767
  print(f"{num_train_images} train images with repeating.")
768
  self.num_train_images = num_train_images
769
 
770
+ # reg imageは数を数えて学習画像と同じ枚数にする
771
+ num_reg_images = 0
772
+ if reg_data_dir:
773
+ print("prepare reg images.")
774
+ reg_infos: List[ImageInfo] = []
 
775
 
776
+ reg_dirs = os.listdir(reg_data_dir)
777
+ for dir in reg_dirs:
778
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
+ num_reg_images += n_repeats * len(img_paths)
780
 
781
+ for img_path, caption in zip(img_paths, captions):
782
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
+ reg_infos.append(info)
784
 
785
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
 
 
 
786
 
787
+ print(f"{num_reg_images} reg images.")
788
+ if num_train_images < num_reg_images:
789
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
 
790
 
791
+ if num_reg_images == 0:
792
+ print("no regularization images / 正則化画像が見つかりませんでした")
 
 
 
793
  else:
794
+ # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
795
+ n = 0
796
+ first_loop = True
797
+ while n < num_train_images:
798
+ for info in reg_infos:
799
+ if first_loop:
800
+ self.register_image(info)
801
+ n += info.num_repeats
802
+ else:
803
+ info.num_repeats += 1
804
+ n += 1
805
+ if n >= num_train_images:
806
+ break
807
+ first_loop = False
808
 
809
+ self.num_reg_images = num_reg_images
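# Simplified sketch of the balancing loop above (assumes every reg image is used
# at least once and uses plain dicts in place of ImageInfo): repeats are topped up
# one at a time until regularization images account for as many steps as train images.
def balance_reg_repeats(reg_infos, num_train_images):
    if not reg_infos:
        return reg_infos
    n = sum(info["num_repeats"] for info in reg_infos)
    while n < num_train_images:
        for info in reg_infos:
            info["num_repeats"] += 1
            n += 1
            if n >= num_train_images:
                break
    return reg_infos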
810
 
 
 
811
 
812
+ class FineTuningDataset(BaseDataset):
813
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
+
817
+ # メタデータを読み込む
818
+ if os.path.exists(json_file_name):
819
+ print(f"loading existing metadata: {json_file_name}")
820
+ with open(json_file_name, "rt", encoding='utf-8') as f:
821
+ metadata = json.load(f)
822
+ else:
823
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
 
825
+ self.metadata = metadata
826
+ self.train_data_dir = train_data_dir
827
+ self.batch_size = batch_size
828
 
829
+ tags_list = []
830
+ for image_key, img_md in metadata.items():
831
+ # path情報を作る
832
+ if os.path.exists(image_key):
833
+ abs_path = image_key
834
+ else:
835
+ # わりといい加減だがいい方法が思いつかん
836
+ abs_path = glob_images(train_data_dir, image_key)
837
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
+ abs_path = abs_path[0]
839
+
840
+ caption = img_md.get('caption')
841
+ tags = img_md.get('tags')
842
+ if caption is None:
843
+ caption = tags
844
+ elif tags is not None and len(tags) > 0:
845
+ caption = caption + ', ' + tags
846
+ tags_list.append(tags)
847
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
+
849
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
+ image_info.image_size = img_md.get('train_resolution')
851
+
852
+ if not self.color_aug and not self.random_crop:
853
+ # if npz exists, use them
854
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
+
856
+ self.register_image(image_info)
857
+ self.num_train_images = len(metadata) * dataset_repeats
858
+ self.num_reg_images = 0
859
 
860
+ # TODO do not record tag freq when no tag
861
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
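# Illustrative shape of the metadata consumed above (all values hypothetical): one
# JSON object keyed by image path, each entry optionally carrying caption, tags and
# the bucketed training resolution.
example_metadata = {
    "img/0001.png": {
        "caption": "a photo of sks girl",   # optional; tags are appended as ", <tags>"
        "tags": "1girl, smile",             # used alone when no caption is present
        "train_resolution": [512, 512],     # stored into image_info.image_size
    }
}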
 
863
 
864
  # check existence of all npz files
865
+ use_npz_latents = not (self.color_aug or self.random_crop)
866
  if use_npz_latents:
 
867
  npz_any = False
868
  npz_all = True
 
869
  for image_info in self.image_data.values():
 
 
870
  has_npz = image_info.latents_npz is not None
871
  npz_any = npz_any or has_npz
872
 
873
+ if self.flip_aug:
874
  has_npz = has_npz and image_info.latents_npz_flipped is not None
 
875
  npz_all = npz_all and has_npz
876
 
877
  if npz_any and not npz_all:
 
883
  elif not npz_all:
884
  use_npz_latents = False
885
  print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
886
+ if self.flip_aug:
887
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
  # else:
889
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
 
929
  for image_info in self.image_data.values():
930
  image_info.latents_npz = image_info.latents_npz_flipped = None
931
 
932
+ def image_key_to_npz_file(self, image_key):
933
  base_name = os.path.splitext(image_key)[0]
934
  npz_file_norm = base_name + '.npz'
935
 
 
941
  return npz_file_norm, npz_file_flip
942
 
943
  # image_key is relative path
944
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
 
947
  if not os.path.exists(npz_file_norm):
948
  npz_file_norm = None
 
953
  return npz_file_norm, npz_file_flip
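# Sketch of the naming scheme resolved above for a relative image_key (helper name
# is hypothetical): latents are expected as "<key>.npz", with an optional
# "<key>_flip.npz" holding the horizontally flipped latents used by flip_aug.
import os

def expected_npz_paths(train_data_dir: str, image_key: str):
    return (os.path.join(train_data_dir, image_key + '.npz'),
            os.path.join(train_data_dir, image_key + '_flip.npz'))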
954
 
955
956
  def debug_dataset(train_dataset, show_input_ids=False):
957
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
  print("Escape for exit. / Escキーで中断、終了します")
959
 
960
  train_dataset.set_current_epoch(1)
961
  k = 0
962
+ for i, example in enumerate(train_dataset):
 
 
 
963
  if example['latents'] is not None:
964
  print(f"sample has latents from npz file: {example['latents'].size()}")
965
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
 
1364
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
 
1368
 
1369
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
 
1387
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
+ parser.add_argument("--use_8bit_adam", action="store_true",
1391
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1393
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
  parser.add_argument("--mem_eff_attn", action="store_true",
1395
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
  parser.add_argument("--xformers", action="store_true",
 
1398
  parser.add_argument("--vae", type=str, default=None,
1399
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
 
1401
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
  parser.add_argument("--max_train_epochs", type=int, default=None,
1404
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
 
1419
  parser.add_argument("--logging_dir", type=str, default=None,
1420
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
  parser.add_argument("--noise_offset", type=float, default=None,
1427
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
  parser.add_argument("--lowram", action="store_true",
1429
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
 
1431
  if support_dreambooth:
1432
  # DreamBooth training
1433
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
 
1449
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
  parser.add_argument("--caption_extention", type=str, default=None,
1451
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
+ parser.add_argument("--keep_tokens", type=int, default=None,
1453
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
 
1475
  if support_caption_dropout:
1476
  # Textual Inversion はcaptionのdropoutをsupportしない
1477
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1478
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
 
1485
  if support_dreambooth:
 
1504
  # region utils
1505
 
1506
 
1507
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
  # backward compatibility
1509
  if args.caption_extention is not None:
1510
  args.caption_extension = args.caption_extention
1511
  args.caption_extention = None
1512
 
1513
+ if args.cache_latents:
1514
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
+
1517
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
  if args.resolution is not None:
1519
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
 
1536
 
1537
  def load_tokenizer(args: argparse.Namespace):
1538
  print("prepare tokenizer")
1539
+ if args.v2:
1540
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
+ else:
1542
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
+ if args.max_token_length is not None:
1544
  print(f"update token length: {args.max_token_length}")
 
 
 
 
 
1545
  return tokenizer
1546
 
1547
 
 
1592
 
1593
 
1594
  def load_target_model(args: argparse.Namespace, weight_dtype):
1595
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
 
 
1596
  if load_stable_diffusion_format:
1597
  print("load StableDiffusion checkpoint")
1598
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
  else:
1600
  print("load Diffusers pretrained models")
1601
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
 
 
 
 
1602
  text_encoder = pipe.text_encoder
1603
  vae = pipe.vae
1604
  unet = pipe.unet
 
1767
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
 
1770
  # endregion
1771
 
1772
  # region 前処理用
lora_train_popup.py ADDED
@@ -0,0 +1,862 @@
1
+ import gc
2
+ import json
3
+ import time
4
+ from functools import partial
5
+ from typing import Union
6
+ import os
7
+ import tkinter as tk
8
+ from tkinter import filedialog as fd, ttk
9
+ from tkinter import simpledialog as sd
10
+ from tkinter import messagebox as mb
11
+
12
+ import torch.cuda
13
+ import train_network
14
+ import library.train_util as util
15
+ import argparse
16
+
17
+
18
+ class ArgStore:
19
+ # Represents all possible inputs for sd-scripts, ordered from most important to least
20
+ def __init__(self):
21
+ # Important, these are the most likely things you will modify
22
+ self.base_model: str = r"" # example path, r"E:\sd\stable-diffusion-webui\models\Stable-diffusion\nai.ckpt"
23
+ self.img_folder: str = r"" # is the folder path to your img folder, make sure to follow the guide here for folder setup: https://rentry.org/2chAI_LoRA_Dreambooth_guide_english#for-kohyas-script
24
+ self.output_folder: str = r"" # just the folder all epochs/safetensors are output
25
+ self.change_output_name: Union[str, None] = None # changes the output name of the epochs
26
+ self.save_json_folder: Union[str, None] = None # OPTIONAL, saves a json folder of your config to whatever location you set here.
27
+ self.load_json_path: Union[str, None] = None  # OPTIONAL, loads a json file and partially changes the config to match it. Things like folder paths do not get modified.
28
+ self.json_load_skip_list: Union[list[str], None] = ["save_json_folder", "reg_img_folder",
29
+ "lora_model_for_resume", "change_output_name",
30
+ "training_comment",
31
+ "json_load_skip_list"] # OPTIONAL, allows the user to define what they skip when loading a json, by default it loads everything, including all paths, set it up like this ["base_model", "img_folder", "output_folder"]
32
+ self.caption_dropout_rate: Union[float, None] = None # The rate at which captions for files get dropped.
33
+ self.caption_dropout_every_n_epochs: Union[int, None] = None # Defines how often an epoch will completely ignore
34
+ # captions, EX. 3 means it will ignore captions at epochs 3, 6, and 9
35
+ self.caption_tag_dropout_rate: Union[float, None] = None # Defines the rate at which a tag would be dropped, rather than the entire caption file
36
+ self.noise_offset: Union[float, None] = None # OPTIONAL, seems to help allow SD to gen better blacks and whites
37
+ # Kohya recommends, if you have it set, to use 0.1, not sure how
38
+ # high the value can be, I'm going to assume maximum of 1
39
+
40
+ self.net_dim: int = 128 # network dimension, 128 is the most common, however you might be able to get lesser to work
41
+ self.alpha: float = 128 # represents the scalar for training. the lower the alpha, the less gets learned per step. if you want the older way of training, set this to dim
42
+ # list of schedulers: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
43
+ self.scheduler: str = "cosine_with_restarts" # the scheduler for learning rate. Each does something specific
44
+ self.cosine_restarts: Union[int, None] = 1 # OPTIONAL, represents the number of times it restarts. Only matters if you are using cosine_with_restarts
45
+ self.scheduler_power: Union[float, None] = 1 # OPTIONAL, represents the power of the polynomial. Only matters if you are using polynomial
46
+ self.warmup_lr_ratio: Union[float, None] = None # OPTIONAL, Calculates the number of warmup steps based on the ratio given. Make sure to set this if you are using constant_with_warmup, None to ignore
47
+ self.learning_rate: Union[float, None] = 1e-4 # OPTIONAL, when not set, lr gets set to 1e-3 as per adamW. Personally, I suggest actually setting this as lower lr seems to be a small bit better.
48
+ self.text_encoder_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the text encoder, this overwrites the base lr I believe, None to ignore
49
+ self.unet_lr: Union[float, None] = None # OPTIONAL, Sets a specific lr for the unet, this overwrites the base lr I believe, None to ignore
50
+ self.num_workers: int = 1 # The number of threads that are being used to load images, lower speeds up the start of epochs, but slows down the loading of data. The assumption here is that it increases the training time as you reduce this value
51
+ self.persistent_workers: bool = True # makes workers persistent, further reduces/eliminates the lag in between epochs. however it may increase memory usage
52
+
53
+ self.batch_size: int = 1 # The number of images that get processed at one time, this is directly proportional to your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 batch size
54
+ self.num_epochs: int = 1 # The number of epochs, if you set max steps this value is ignored as it doesn't calculate steps.
55
+ self.save_every_n_epochs: Union[int, None] = None # OPTIONAL, how often to save epochs, None to ignore
56
+ self.shuffle_captions: bool = False # OPTIONAL, False to ignore
57
+ self.keep_tokens: Union[int, None] = None # OPTIONAL, None to ignore
58
+ self.max_steps: Union[int, None] = None # OPTIONAL, if you have specific steps you want to hit, this allows you to set it directly. None to ignore
59
+ self.tag_occurrence_txt_file: bool = False # OPTIONAL, creates a txt file that has the entire occurrence of all tags in your dataset
60
+ # the metadata will also have this so long as you have metadata on, so no reason to have this on by default
61
+ # will automatically output to the same folder as your output checkpoints
62
+ self.sort_tag_occurrence_alphabetically: bool = False # OPTIONAL, only applies if tag_occurrence_txt_file is also true
63
+ # Will change the output to be alphabetically vs being occurrence based
64
+
65
+ # These are the second most likely things you will modify
66
+ self.train_resolution: int = 512
67
+ self.min_bucket_resolution: int = 320
68
+ self.max_bucket_resolution: int = 960
69
+ self.lora_model_for_resume: Union[str, None] = None # OPTIONAL, takes an input lora to continue training from, not exactly the way it *should* be, but it works, None to ignore
70
+ self.save_state: bool = False # OPTIONAL, is the intended way to save a training state to use for continuing training, False to ignore
71
+ self.load_previous_save_state: Union[str, None] = None # OPTIONAL, is the intended way to load a training state to use for continuing training, None to ignore
72
+ self.training_comment: Union[str, None] = None # OPTIONAL, great way to put in things like activation tokens right into the metadata. seems to not work at this point and time
73
+ self.unet_only: bool = False # OPTIONAL, set it to only train the unet
74
+ self.text_only: bool = False # OPTIONAL, set it to only train the text encoder
75
+
76
+ # These are the least likely things you will modify
77
+ self.reg_img_folder: Union[str, None] = None # OPTIONAL, None to ignore
78
+ self.clip_skip: int = 2 # If you are training on a model that is anime based, keep this at 2 as most models are designed for that
79
+ self.test_seed: int = 23  # this is the "reproducible seed", basically if you set the seed to this, you should be able to input a prompt from one of your training images and get a close representation of it
80
+ self.prior_loss_weight: float = 1 # is the loss weight much like Dreambooth, is required for LoRA training
81
+ self.gradient_checkpointing: bool = False # OPTIONAL, enables gradient checkpointing
82
+ self.gradient_acc_steps: Union[int, None] = None # OPTIONAL, not sure exactly what this means
83
+ self.mixed_precision: str = "fp16" # If you have the ability to use bf16, do it, it's better
84
+ self.save_precision: str = "fp16" # You can also save in bf16, but because it's not universally supported, I suggest you keep saving at fp16
85
+ self.save_as: str = "safetensors" # list is pt, ckpt, safetensors
86
+ self.caption_extension: str = ".txt" # the other option is .captions, but since wd1.4 tagger outputs as txt files, this is the default
87
+ self.max_clip_token_length = 150 # can be 75, 150, or 225 I believe, there is no reason to go higher than 150 though
88
+ self.buckets: bool = True
89
+ self.xformers: bool = True
90
+ self.use_8bit_adam: bool = True
91
+ self.cache_latents: bool = True
92
+ self.color_aug: bool = False # IMPORTANT: Clashes with cache_latents, only have one of the two on!
93
+ self.flip_aug: bool = False
94
+ self.vae: Union[str, None] = None # Seems to only make results worse when not using that specific vae, should probably not use
95
+ self.no_meta: bool = False # This removes the metadata that now gets saved into safetensors, (you should keep this on)
96
+ self.log_dir: Union[str, None] = None # output of logs, not useful to most people.
97
+ self.v2: bool = False # Sets up training for SD2.1
98
+ self.v_parameterization: bool = False # Only is used when v2 is also set and you are using the 768x version of v2
99
+
100
+ # Creates the dict that is used for the rest of the code, to facilitate easier json saving and loading
101
+ @staticmethod
102
+ def convert_args_to_dict():
103
+ return ArgStore().__dict__
104
+
105
+
106
+ def main():
107
+ parser = argparse.ArgumentParser()
108
+ setup_args(parser)
109
+ pre_args = parser.parse_args()
110
+ queues = 0
111
+ args_queue = []
112
+ cont = True
113
+ while cont:
114
+ arg_dict = ArgStore.convert_args_to_dict()
115
+ ret = mb.askyesno(message="Do you want to load a json config file?")
116
+ if ret:
117
+ load_json(ask_file("select json to load from", {"json"}), arg_dict)
118
+ arg_dict = ask_elements_trunc(arg_dict)
119
+ else:
120
+ arg_dict = ask_elements(arg_dict)
121
+ if pre_args.save_json_path or arg_dict["save_json_folder"]:
122
+ save_json(pre_args.save_json_path if pre_args.save_json_path else arg_dict['save_json_folder'], arg_dict)
123
+ args = create_arg_space(arg_dict)
124
+ args = parser.parse_args(args)
125
+ queues += 1
126
+ args_queue.append(args)
127
+ if arg_dict['tag_occurrence_txt_file']:
128
+ get_occurrence_of_tags(arg_dict)
129
+ ret = mb.askyesno(message="Do you want to queue another training?")
130
+ if not ret:
131
+ cont = False
132
+ for args in args_queue:
133
+ try:
134
+ train_network.train(args)
135
+ except Exception as e:
136
+ print(f"Failed to train this set of args.\nSkipping this training session.\nError is: {e}")
137
+ gc.collect()
138
+ torch.cuda.empty_cache()
139
+
140
+
141
+ def create_arg_space(args: dict) -> [str]:
142
+ # This is the list of args that are to be used regardless of setup
143
+ output = ["--network_module=networks.lora", f"--pretrained_model_name_or_path={args['base_model']}",
144
+ f"--train_data_dir={args['img_folder']}", f"--output_dir={args['output_folder']}",
145
+ f"--prior_loss_weight={args['prior_loss_weight']}", f"--caption_extension=" + args['caption_extension'],
146
+ f"--resolution={args['train_resolution']}", f"--train_batch_size={args['batch_size']}",
147
+ f"--mixed_precision={args['mixed_precision']}", f"--save_precision={args['save_precision']}",
148
+ f"--network_dim={args['net_dim']}", f"--save_model_as={args['save_as']}",
149
+ f"--clip_skip={args['clip_skip']}", f"--seed={args['test_seed']}",
150
+ f"--max_token_length={args['max_clip_token_length']}", f"--lr_scheduler={args['scheduler']}",
151
+ f"--network_alpha={args['alpha']}", f"--max_data_loader_n_workers={args['num_workers']}"]
152
+ if not args['max_steps']:
153
+ output.append(f"--max_train_epochs={args['num_epochs']}")
154
+ output += create_optional_args(args, find_max_steps(args))
155
+ else:
156
+ output.append(f"--max_train_steps={args['max_steps']}")
157
+ output += create_optional_args(args, args['max_steps'])
158
+ return output
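# Illustrative output of create_arg_space for a default ArgStore (the paths here
# are made up); these strings are handed to parser.parse_args() and then to
# train_network.train():
example_required = [
    "--network_module=networks.lora",
    "--pretrained_model_name_or_path=E:/models/base.ckpt",
    "--train_data_dir=E:/datasets/char", "--output_dir=E:/output",
    "--prior_loss_weight=1", "--caption_extension=.txt",
    "--resolution=512", "--train_batch_size=1",
    "--mixed_precision=fp16", "--save_precision=fp16",
    "--network_dim=128", "--save_model_as=safetensors",
    "--clip_skip=2", "--seed=23", "--max_token_length=150",
    "--lr_scheduler=cosine_with_restarts", "--network_alpha=128",
    "--max_data_loader_n_workers=1",
    "--max_train_epochs=1",  # added because max_steps is unset
]  # plus the optional flags from create_optional_args below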
159
+
160
+
161
+ def create_optional_args(args: dict, steps):
162
+ output = []
163
+ if args["reg_img_folder"]:
164
+ output.append(f"--reg_data_dir={args['reg_img_folder']}")
165
+
166
+ if args['lora_model_for_resume']:
167
+ output.append(f"--network_weights={args['lora_model_for_resume']}")
168
+
169
+ if args['save_every_n_epochs']:
170
+ output.append(f"--save_every_n_epochs={args['save_every_n_epochs']}")
171
+ else:
172
+ output.append("--save_every_n_epochs=999999")
173
+
174
+ if args['shuffle_captions']:
175
+ output.append("--shuffle_caption")
176
+
177
+ if args['keep_tokens'] and args['keep_tokens'] > 0:
178
+ output.append(f"--keep_tokens={args['keep_tokens']}")
179
+
180
+ if args['buckets']:
181
+ output.append("--enable_bucket")
182
+ output.append(f"--min_bucket_reso={args['min_bucket_resolution']}")
183
+ output.append(f"--max_bucket_reso={args['max_bucket_resolution']}")
184
+
185
+ if args['use_8bit_adam']:
186
+ output.append("--use_8bit_adam")
187
+
188
+ if args['xformers']:
189
+ output.append("--xformers")
190
+
191
+ if args['color_aug']:
192
+ if args['cache_latents']:
193
+ print("color_aug and cache_latents conflict with one another. Please select only one")
194
+ quit(1)
195
+ output.append("--color_aug")
196
+
197
+ if args['flip_aug']:
198
+ output.append("--flip_aug")
199
+
200
+ if args['cache_latents']:
201
+ output.append("--cache_latents")
202
+
203
+ if args['warmup_lr_ratio'] and args['warmup_lr_ratio'] > 0:
204
+ warmup_steps = int(steps * args['warmup_lr_ratio'])
205
+ output.append(f"--lr_warmup_steps={warmup_steps}")
206
+
207
+ if args['gradient_checkpointing']:
208
+ output.append("--gradient_checkpointing")
209
+
210
+ if args['gradient_acc_steps'] and args['gradient_acc_steps'] > 0 and args['gradient_checkpointing']:
211
+ output.append(f"--gradient_accumulation_steps={args['gradient_acc_steps']}")
212
+
213
+ if args['learning_rate'] and args['learning_rate'] > 0:
214
+ output.append(f"--learning_rate={args['learning_rate']}")
215
+
216
+ if args['text_encoder_lr'] and args['text_encoder_lr'] > 0:
217
+ output.append(f"--text_encoder_lr={args['text_encoder_lr']}")
218
+
219
+ if args['unet_lr'] and args['unet_lr'] > 0:
220
+ output.append(f"--unet_lr={args['unet_lr']}")
221
+
222
+ if args['vae']:
223
+ output.append(f"--vae={args['vae']}")
224
+
225
+ if args['no_meta']:
226
+ output.append("--no_metadata")
227
+
228
+ if args['save_state']:
229
+ output.append("--save_state")
230
+
231
+ if args['load_previous_save_state']:
232
+ output.append(f"--resume={args['load_previous_save_state']}")
233
+
234
+ if args['change_output_name']:
235
+ output.append(f"--output_name={args['change_output_name']}")
236
+
237
+ if args['training_comment']:
238
+ output.append(f"--training_comment={args['training_comment']}")
239
+
240
+ if args['cosine_restarts'] and args['scheduler'] == "cosine_with_restarts":
241
+ output.append(f"--lr_scheduler_num_cycles={args['cosine_restarts']}")
242
+
243
+ if args['scheduler_power'] and args['scheduler'] == "polynomial":
244
+ output.append(f"--lr_scheduler_power={args['scheduler_power']}")
245
+
246
+ if args['persistent_workers']:
247
+ output.append(f"--persistent_data_loader_workers")
248
+
249
+ if args['unet_only']:
250
+ output.append("--network_train_unet_only")
251
+
252
+ if args['text_only'] and not args['unet_only']:
253
+ output.append("--network_train_text_encoder_only")
254
+
255
+ if args["log_dir"]:
256
+ output.append(f"--logging_dir={args['log_dir']}")
257
+
258
+ if args['caption_dropout_rate']:
259
+ output.append(f"--caption_dropout_rate={args['caption_dropout_rate']}")
260
+
261
+ if args['caption_dropout_every_n_epochs']:
262
+ output.append(f"--caption_dropout_every_n_epochs={args['caption_dropout_every_n_epochs']}")
263
+
264
+ if args['caption_tag_dropout_rate']:
265
+ output.append(f"--caption_tag_dropout_rate={args['caption_tag_dropout_rate']}")
266
+
267
+ if args['v2']:
268
+ output.append("--v2")
269
+
270
+ if args['v2'] and args['v_parameterization']:
271
+ output.append("--v_parameterization")
272
+
273
+ if args['noise_offset']:
274
+ output.append(f"--noise_offset={args['noise_offset']}")
275
+ return output
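# And roughly the optional flags the builder above would add for the same default
# ArgStore (order follows the checks above; only flags whose settings are enabled
# by default appear):
example_optional = [
    "--save_every_n_epochs=999999",    # fallback when no save interval is chosen
    "--enable_bucket", "--min_bucket_reso=320", "--max_bucket_reso=960",
    "--use_8bit_adam", "--xformers", "--cache_latents",
    "--learning_rate=0.0001",
    "--lr_scheduler_num_cycles=1",     # scheduler defaults to cosine_with_restarts
    "--persistent_data_loader_workers",
]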
276
+
277
+
278
+ def find_max_steps(args: dict) -> int:
279
+ total_steps = 0
280
+ folders = os.listdir(args["img_folder"])
281
+ for folder in folders:
282
+ if not os.path.isdir(os.path.join(args["img_folder"], folder)):
283
+ continue
284
+ num_repeats = folder.split("_")
285
+ if len(num_repeats) < 2:
286
+ print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
287
+ continue
288
+ try:
289
+ num_repeats = int(num_repeats[0])
290
+ except ValueError:
291
+ print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
292
+ continue
293
+ imgs = 0
294
+ for file in os.listdir(os.path.join(args["img_folder"], folder)):
295
+ if os.path.isdir(file):
296
+ continue
297
+ ext = file.split(".")
298
+ if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
299
+ imgs += 1
300
+ total_steps += (num_repeats * imgs)
301
+ total_steps = int((total_steps / args["batch_size"]) * args["num_epochs"])
302
+ return total_steps
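# Worked example of the arithmetic above (hypothetical dataset): a folder "10_char"
# holding 200 images contributes 10 * 200 = 2000 raw steps; with batch_size=2 and
# num_epochs=5 the estimate becomes int((2000 / 2) * 5) = 5000.
num_repeats, imgs, batch_size, num_epochs = 10, 200, 2, 5
assert int((num_repeats * imgs / batch_size) * num_epochs) == 5000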
303
+
304
+
305
+ def add_misc_args(parser):
306
+ parser.add_argument("--save_json_path", type=str, default=None,
307
+ help="Path to save a configuration json file to")
308
+ parser.add_argument("--load_json_path", type=str, default=None,
309
+ help="Path to a json file to configure things from")
310
+ parser.add_argument("--no_metadata", action='store_true',
311
+ help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
312
+ parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
313
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
314
+
315
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
316
+ parser.add_argument("--text_encoder_lr", type=float, default=None,
317
+ help="learning rate for Text Encoder / Text Encoderの学習率")
318
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
319
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
320
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
321
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
322
+
323
+ parser.add_argument("--network_weights", type=str, default=None,
324
+ help="pretrained weights for network / 学習するネットワークの初期重み")
325
+ parser.add_argument("--network_module", type=str, default=None,
326
+ help='network module to train / 学習対象のネットワークのモジュール')
327
+ parser.add_argument("--network_dim", type=int, default=None,
328
+ help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
329
+ parser.add_argument("--network_alpha", type=float, default=1,
330
+ help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
331
+ parser.add_argument("--network_args", type=str, default=None, nargs='*',
332
+ help='additional arguments for network (key=value) / ネットワークへの追加の引数')
333
+ parser.add_argument("--network_train_unet_only", action="store_true",
334
+ help="only training U-Net part / U-Net関連部分のみ学習する")
335
+ parser.add_argument("--network_train_text_encoder_only", action="store_true",
336
+ help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
337
+ parser.add_argument("--training_comment", type=str, default=None,
338
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
339
+
340
+
341
+ def setup_args(parser):
342
+ util.add_sd_models_arguments(parser)
343
+ util.add_dataset_arguments(parser, True, True, True)
344
+ util.add_training_arguments(parser, True)
345
+ add_misc_args(parser)
346
+
347
+
348
+ def get_occurrence_of_tags(args):
349
+ extension = args['caption_extension']
350
+ img_folder = args['img_folder']
351
+ output_folder = args['output_folder']
352
+ occurrence_dict = {}
353
+ print(img_folder)
354
+ for folder in os.listdir(img_folder):
355
+ print(folder)
356
+ if not os.path.isdir(os.path.join(img_folder, folder)):
357
+ continue
358
+ for file in os.listdir(os.path.join(img_folder, folder)):
359
+ if not os.path.isfile(os.path.join(img_folder, folder, file)):
360
+ continue
361
+ ext = os.path.splitext(file)[1]
362
+ if ext != extension:
363
+ continue
364
+ get_tags_from_file(os.path.join(img_folder, folder, file), occurrence_dict)
365
+ if not args['sort_tag_occurrence_alphabetically']:
366
+ output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[1], reverse=True)}
367
+ else:
368
+ output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[0])}
369
+ name = args['change_output_name'] if args['change_output_name'] else "last"
370
+ with open(os.path.join(output_folder, f"{name}.txt"), "w") as f:
371
+ f.write(f"Below is a list of keywords used during the training of {args['change_output_name']}:\n")
372
+ for k, v in output_list.items():
373
+ f.write(f"[{v}] {k}\n")
374
+ print(f"Created a txt file named {name}.txt in the output folder")
375
+
376
+
377
+ def get_tags_from_file(file, occurrence_dict):
378
+ f = open(file)
379
+ temp = f.read().replace(", ", ",").split(",")
380
+ f.close()
381
+ for tag in temp:
382
+ if tag in occurrence_dict:
383
+ occurrence_dict[tag] += 1
384
+ else:
385
+ occurrence_dict[tag] = 1
386
+
387
+
388
+ def ask_file(message, accepted_ext_list, file_path=None):
389
+ mb.showinfo(message=message)
390
+ res = ""
391
+ _initialdir = ""
392
+ _initialfile = ""
393
+ if file_path != None:
394
+ _initialdir = os.path.dirname(file_path) if os.path.exists(file_path) else ""
395
+ _initialfile = os.path.basename(file_path) if os.path.exists(file_path) else ""
396
+
397
+ while res == "":
398
+ res = fd.askopenfilename(title=message, initialdir=_initialdir, initialfile=_initialfile)
399
+ if res == "" or type(res) == tuple:
400
+ ret = mb.askretrycancel(message="Do you want to cancel training?")
401
+ if not ret:
402
+ exit()
403
+ continue
404
+ elif not os.path.exists(res):
405
+ res = ""
406
+ continue
407
+ _, name = os.path.split(res)
408
+ split_name = name.split(".")
409
+ if split_name[-1] not in accepted_ext_list:
410
+ res = ""
411
+ return res
412
+
413
+
414
+ def ask_dir(message, dir_path=None):
415
+ mb.showinfo(message=message)
416
+ res = ""
417
+ _initialdir = ""
418
+ if dir_path != None:
419
+ _initialdir = dir_path if os.path.exists(dir_path) else ""
420
+ while res == "":
421
+ res = fd.askdirectory(title=message, initialdir=_initialdir)
422
+ if res == "" or type(res) == tuple:
423
+ ret = mb.askretrycancel(message="Do you want to cancel training?")
424
+ if not ret:
425
+ exit()
426
+ continue
427
+ if not os.path.exists(res):
428
+ res = ""
429
+ return res
430
+
431
+
432
+ def ask_elements_trunc(args: dict):
433
+ args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
434
+ args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
435
+ args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
436
+
437
+ ret = mb.askyesno(message="Do you want to save a json of your configuration?")
438
+ if ret:
439
+ args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
440
+ else:
441
+ args['save_json_folder'] = None
442
+
443
+ ret = mb.askyesno(message="Are you training on a SD2 based model?")
444
+ if ret:
445
+ args['v2'] = True
446
+
447
+ ret = mb.askyesno(message="Are you training on an realistic model?")
448
+ if ret:
449
+ args['clip_skip'] = 1
450
+
451
+ if args['v2']:
452
+ ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
453
+ if ret:
454
+ args['v_parameterization'] = True
455
+
456
+ ret = mb.askyesno(message="Do you want to use regularization images?")
457
+ if ret:
458
+ args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
459
+ else:
460
+ args['reg_img_folder'] = None
461
+
462
+ ret = mb.askyesno(message="Do you want to continue from an earlier version?")
463
+ if ret:
464
+ args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
465
+ args['lora_model_for_resume'])
466
+ else:
467
+ args['lora_model_for_resume'] = None
468
+
469
+ ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
470
+ "within your dataset but it can also ruin learning an asymmetrical element\n")
471
+ if ret:
472
+ args['flip_aug'] = True
473
+
474
+ ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
475
+ if ret:
476
+ ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
477
+ "Cancel keeps outputs the original")
478
+ if ret:
479
+ args['change_output_name'] = ret
480
+ else:
481
+ args['change_output_name'] = None
482
+
483
+ ret = sd.askstring(title="comment",
484
+ prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
485
+ "be to include how to use, such as activation keywords.\nCancel will leave empty")
486
+ if ret:
487
+ args['training_comment'] = ret
488
+ else:
489
+ args['training_comment'] = None
490
+
491
+ ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
492
+ if ret:
493
+ button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
494
+ button.window.mainloop()
495
+ if button.current_value != "":
496
+ args[button.current_value] = True
497
+
498
+ ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
499
+ "of all tags that you have used in your training data?\n")
500
+ if ret:
501
+ args['tag_occurrence_txt_file'] = True
502
+ button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
503
+ button.window.mainloop()
504
+ if button.current_value == "alphabetically":
505
+ args['sort_tag_occurrence_alphabetically'] = True
506
+
507
+ ret = mb.askyesno(message="Do you want to use caption dropout?")
508
+ if ret:
509
+ ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
510
+ if ret:
511
+ ret = sd.askinteger(title="Caption_File_Dropout",
512
+ prompt="How often do you want caption files to drop out?\n"
513
+ "enter a number from 0 to 100 that is the percentage chance of dropout\n"
514
+ "Cancel sets to 0")
515
+ if ret and 0 <= ret <= 100:
516
+ args['caption_dropout_rate'] = ret / 100.0
517
+
518
+ ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
519
+ if ret:
520
+ ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
521
+ "epoch with no captions\nSo if you set 3, then every"
522
+ "three epochs will not have captions (3, 6, 9)\n"
523
+ "Cancel will set to None")
524
+ if ret:
525
+ args['caption_dropout_every_n_epochs'] = ret
526
+
527
+ ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
528
+ if ret:
529
+ ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
530
+ "Enter a number between 0 and 100, that is the percentage"
531
+ "chance of dropout.\nCancel sets to 0")
532
+ if ret and 0 <= ret <= 100:
533
+ args['caption_tag_dropout_rate'] = ret / 100.0
534
+
535
+ ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
536
+ "darker or lighter images using this than normal.")
537
+ if ret:
538
+ ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
539
+ "but it can go higher. Cancel defaults to 0.1")
540
+ if ret:
541
+ args['noise_offset'] = ret
542
+ else:
543
+ args['noise_offset'] = 0.1
544
+ return args
545
+
546
+
547
+ def ask_elements(args: dict):
548
+ # start with file dialog
549
+ args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
550
+ args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
551
+ args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])
552
+
553
+ # optional file dialog
554
+ ret = mb.askyesno(message="Do you want to save a json of your configuration?")
555
+ if ret:
556
+ args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
557
+ else:
558
+ args['save_json_folder'] = None
559
+
560
+ ret = mb.askyesno(message="Are you training on a SD2 based model?")
561
+ if ret:
562
+ args['v2'] = True
563
+
564
+ ret = mb.askyesno(message="Are you training on an realistic model?")
565
+ if ret:
566
+ args['clip_skip'] = 1
567
+
568
+ if args['v2']:
569
+ ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
570
+ if ret:
571
+ args['v_parameterization'] = True
572
+
573
+ ret = mb.askyesno(message="Do you want to use regularization images?")
574
+ if ret:
575
+ args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
576
+ else:
577
+ args['reg_img_folder'] = None
578
+
579
+ ret = mb.askyesno(message="Do you want to continue from an earlier version?")
580
+ if ret:
581
+ args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
582
+ args['lora_model_for_resume'])
583
+ else:
584
+ args['lora_model_for_resume'] = None
585
+
586
+ ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
587
+ "within your dataset but it can also ruin learning an asymmetrical element\n")
588
+ if ret:
589
+ args['flip_aug'] = True
590
+
591
+ # text based required elements
592
+ ret = sd.askinteger(title="batch_size",
593
+ prompt="The number of images that get processed at one time, this is directly proportional to "
594
+ "your vram and resolution. with 12gb of vram, at 512 reso, you can get a maximum of 6 "
595
+ "batch size\nHow large is your batch size going to be?\nCancel will default to 1")
596
+ if ret is None:
597
+ args['batch_size'] = 1
598
+ else:
599
+ args['batch_size'] = ret
600
+
601
+ ret = sd.askinteger(title="num_epochs", prompt="How many epochs do you want?\nCancel will default to 1")
602
+ if ret is None:
603
+ args['num_epochs'] = 1
604
+ else:
605
+ args['num_epochs'] = ret
606
+
607
+ ret = sd.askinteger(title="network_dim", prompt="What is the dim size you want to use?\nCancel will default to 128")
608
+ if ret is None:
609
+ args['net_dim'] = 128
610
+ else:
611
+ args['net_dim'] = ret
612
+
613
+ ret = sd.askfloat(title="alpha", prompt="Alpha is the scalar of the training, generally a good starting point is "
614
+ "0.5x dim size\nWhat Alpha do you want?\nCancel will default to equal to "
615
+ "0.5 x network_dim")
616
+ if ret is None:
617
+ args['alpha'] = args['net_dim'] / 2
618
+ else:
619
+ args['alpha'] = ret
620
+
621
+ ret = sd.askinteger(title="resolution", prompt="How large of a resolution do you want to train at?\n"
622
+ "Cancel will default to 512")
623
+ if ret is None:
624
+ args['train_resolution'] = 512
625
+ else:
626
+ args['train_resolution'] = ret
627
+
628
+ ret = sd.askfloat(title="learning_rate", prompt="What learning rate do you want to use?\n"
629
+ "Cancel will default to 1e-4")
630
+ if ret is None:
631
+ args['learning_rate'] = 1e-4
632
+ else:
633
+ args['learning_rate'] = ret
634
+
635
+ ret = sd.askfloat(title="text_encoder_lr", prompt="Do you want to set the text_encoder_lr?\n"
636
+ "Cancel will default to None")
637
+ if ret is None:
638
+ args['text_encoder_lr'] = None
639
+ else:
640
+ args['text_encoder_lr'] = ret
641
+
642
+ ret = sd.askfloat(title="unet_lr", prompt="Do you want to set the unet_lr?\nCancel will default to None")
643
+ if ret is None:
644
+ args['unet_lr'] = None
645
+ else:
646
+ args['unet_lr'] = ret
647
+
648
+ button = ButtonBox("Which scheduler do you want?", ["cosine_with_restarts", "cosine", "polynomial",
649
+ "constant", "constant_with_warmup", "linear"])
650
+ button.window.mainloop()
651
+ args['scheduler'] = button.current_value if button.current_value != "" else "cosine_with_restarts"
652
+
653
+ if args['scheduler'] == "cosine_with_restarts":
654
+ ret = sd.askinteger(title="Cycle Count",
655
+ prompt="How many times do you want cosine to restart?\nThis is the entire amount of times "
656
+ "it will restart for the entire training\nCancel will default to 1")
657
+ if ret is None:
658
+ args['cosine_restarts'] = 1
659
+ else:
660
+ args['cosine_restarts'] = ret
661
+
662
+ if args['scheduler'] == "polynomial":
663
+ ret = sd.askfloat(title="Poly Strength",
664
+ prompt="What power do you want to set your polynomial to?\nhigher power means that the "
665
+ "model reduces the learning more more aggressively from initial training.\n1 = "
666
+ "linear\nCancel sets to 1")
667
+ if ret is None:
668
+ args['scheduler_power'] = 1
669
+ else:
670
+ args['scheduler_power'] = ret
671
+
672
+ ret = mb.askyesno(message="Do you want to save epochs as it trains?")
673
+ if ret:
674
+ ret = sd.askinteger(title="save_epoch",
675
+ prompt="How often do you want to save epochs?\nCancel will default to 1")
676
+ if ret is None:
677
+ args['save_every_n_epochs'] = 1
678
+ else:
679
+ args['save_every_n_epochs'] = ret
680
+
681
+ ret = mb.askyesno(message="Do you want to shuffle captions?")
682
+ if ret:
683
+ args['shuffle_captions'] = True
684
+ else:
685
+ args['shuffle_captions'] = False
686
+
687
+ ret = mb.askyesno(message="Do you want to keep some tokens at the front of your captions?")
688
+ if ret:
689
+ ret = sd.askinteger(title="keep_tokens", prompt="How many do you want to keep at the front?"
690
+ "\nCancel will default to 1")
691
+ if ret is None:
692
+ args['keep_tokens'] = 1
693
+ else:
694
+ args['keep_tokens'] = ret
695
+
696
+ ret = mb.askyesno(message="Do you want to have a warmup ratio?")
697
+ if ret:
698
+ ret = sd.askfloat(title="warmup_ratio", prompt="What is the ratio of steps to use as warmup "
699
+ "steps?\nCancel will default to None")
700
+ if ret is None:
701
+ args['warmup_lr_ratio'] = None
702
+ else:
703
+ args['warmup_lr_ratio'] = ret
704
+
705
+ ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
706
+ if ret:
707
+ ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
708
+ "Cancel keeps outputs the original")
709
+ if ret:
710
+ args['change_output_name'] = ret
711
+ else:
712
+ args['change_output_name'] = None
713
+
714
+ ret = sd.askstring(title="comment",
715
+ prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
716
+ "be to include how to use, such as activation keywords.\nCancel will leave empty")
717
+ if ret:
718
+ args['training_comment'] = ret
719
+ else:
720
+ args['training_comment'] = None
721
+
722
+ ret = mb.askyesno(message="Do you want to train only one of unet and text encoder?")
723
+ if ret:
724
+ if ret:
725
+ button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
726
+ button.window.mainloop()
727
+ if button.current_value != "":
728
+ args[button.current_value] = True
729
+
730
+ ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
731
+ "of all tags that you have used in your training data?\n")
732
+ if ret:
733
+ args['tag_occurrence_txt_file'] = True
734
+ button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
735
+ button.window.mainloop()
736
+ if button.current_value == "alphabetically":
737
+ args['sort_tag_occurrence_alphabetically'] = True
738
+
739
+ ret = mb.askyesno(message="Do you want to use caption dropout?")
740
+ if ret:
741
+ ret = mb.askyesno(message="Do you want full caption files to dropout randomly?")
742
+ if ret:
743
+ ret = sd.askinteger(title="Caption_File_Dropout",
744
+ prompt="How often do you want caption files to drop out?\n"
745
+ "enter a number from 0 to 100 that is the percentage chance of dropout\n"
746
+ "Cancel sets to 0")
747
+ if ret and 0 <= ret <= 100:
748
+ args['caption_dropout_rate'] = ret / 100.0
749
+
750
+ ret = mb.askyesno(message="Do you want to have full epochs have no captions?")
751
+ if ret:
752
+ ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often you will have an"
753
+ "epoch with no captions\nSo if you set 3, then every"
754
+ "three epochs will not have captions (3, 6, 9)\n"
755
+ "Cancel will set to None")
756
+ if ret:
757
+ args['caption_dropout_every_n_epochs'] = ret
758
+
759
+ ret = mb.askyesno(message="Do you want to have tags to randomly drop?")
760
+ if ret:
761
+ ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
762
+ "Enter a number between 0 and 100, that is the percentage"
763
+ "chance of dropout.\nCancel sets to 0")
764
+ if ret and 0 <= ret <= 100:
765
+ args['caption_tag_dropout_rate'] = ret / 100.0
766
+
767
+ ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow for SD to better generate\n"
768
+ "darker or lighter images using this than normal.")
769
+ if ret:
770
+ ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? recommended value is 0.1,\n"
771
+ "but it can go higher. Cancel defaults to 0.1")
772
+ if ret:
773
+ args['noise_offset'] = ret
774
+ else:
775
+ args['noise_offset'] = 0.1
776
+ return args
777
+
778
+
779
+ def save_json(path, obj: dict) -> None:
780
+ fp = open(os.path.join(path, f"config-{time.time()}.json"), "w")
781
+ json.dump(obj, fp=fp, indent=4)
782
+ fp.close()
783
+
784
+
785
+ def load_json(path, obj: dict) -> dict:
786
+ with open(path) as f:
787
+ json_obj = json.loads(f.read())
788
+ print("loaded json, setting variables...")
789
+ ui_name_scheme = {"pretrained_model_name_or_path": "base_model", "logging_dir": "log_dir",
790
+ "train_data_dir": "img_folder", "reg_data_dir": "reg_img_folder",
791
+ "output_dir": "output_folder", "max_resolution": "train_resolution",
792
+ "lr_scheduler": "scheduler", "lr_warmup": "warmup_lr_ratio",
793
+ "train_batch_size": "batch_size", "epoch": "num_epochs",
794
+ "save_at_n_epochs": "save_every_n_epochs", "num_cpu_threads_per_process": "num_workers",
795
+ "enable_bucket": "buckets", "save_model_as": "save_as", "shuffle_caption": "shuffle_captions",
796
+ "resume": "load_previous_save_state", "network_dim": "net_dim",
797
+ "gradient_accumulation_steps": "gradient_acc_steps", "output_name": "change_output_name",
798
+ "network_alpha": "alpha", "lr_scheduler_num_cycles": "cosine_restarts",
799
+ "lr_scheduler_power": "scheduler_power"}
800
+
801
+ for key in list(json_obj):
802
+ if key in ui_name_scheme:
803
+ json_obj[ui_name_scheme[key]] = json_obj[key]
804
+ if ui_name_scheme[key] in {"batch_size", "num_epochs"}:
805
+ try:
806
+ json_obj[ui_name_scheme[key]] = int(json_obj[ui_name_scheme[key]])
807
+ except ValueError:
808
+ print(f"attempting to load {key} from json failed as input isn't an integer")
809
+ quit(1)
810
+
811
+ for key in list(json_obj):
812
+ if obj["json_load_skip_list"] and key in obj["json_load_skip_list"]:
813
+ continue
814
+ if key in obj:
815
+ if key in {"keep_tokens", "warmup_lr_ratio"}:
816
+ json_obj[key] = int(json_obj[key]) if json_obj[key] is not None else None
817
+ if key in {"learning_rate", "unet_lr", "text_encoder_lr"}:
818
+ json_obj[key] = float(json_obj[key]) if json_obj[key] is not None else None
819
+ if obj[key] != json_obj[key]:
820
+ print_change(key, obj[key], json_obj[key])
821
+ obj[key] = json_obj[key]
822
+ print("completed changing variables.")
823
+ return obj
824
+
825
+
826
+ def print_change(value, old, new):
827
+ print(f"{value} changed from {old} to {new}")
828
+
829
+
830
+ class ButtonBox:
831
+ def __init__(self, label: str, button_name_list: list[str]) -> None:
832
+ self.window = tk.Tk()
833
+ self.button_list = []
834
+ self.current_value = ""
835
+
836
+ self.window.attributes("-topmost", True)
837
+ self.window.resizable(False, False)
838
+ self.window.eval('tk::PlaceWindow . center')
839
+
840
+ def del_window():
841
+ self.window.quit()
842
+ self.window.destroy()
843
+
844
+ self.window.protocol("WM_DELETE_WINDOW", del_window)
845
+ tk.Label(text=label, master=self.window).pack()
846
+ for button in button_name_list:
847
+ self.button_list.append(ttk.Button(text=button, master=self.window,
848
+ command=partial(self.set_current_value, button)))
849
+ self.button_list[-1].pack()
850
+
851
+ def set_current_value(self, value):
852
+ self.current_value = value
853
+ self.window.quit()
854
+ self.window.destroy()
855
+
856
+
857
+ root = tk.Tk()
858
+ root.attributes('-topmost', True)
859
+ root.withdraw()
860
+
861
+ if __name__ == "__main__":
862
+ main()
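For reference, the save_json and load_json helpers defined above round-trip the collected options through a timestamped JSON file; a minimal sketch of that flow, run inside the same script (the option values and the glob lookup are illustrative, not part of this commit):

import glob

opts = {"json_load_skip_list": None, "batch_size": 2, "num_epochs": 10}
save_json(".", opts)                              # writes ./config-<timestamp>.json
latest = sorted(glob.glob("config-*.json"))[-1]
restored = load_json(latest, dict(opts))          # prints any values that changed, then returns the dict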
lycoris/kohya.py CHANGED
@@ -5,7 +5,6 @@
5
  # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
 
7
  import math
8
- from warnings import warn
9
  import os
10
  from typing import List
11
  import torch
@@ -28,22 +27,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
28
  }[algo]
29
 
30
  print(f'Using rank adaptation algo: {algo}')
31
-
32
- if (algo == 'loha'
33
- and not kwargs.get('no_dim_warn', False)
34
- and (network_dim>64 or conv_dim>64)):
35
- print('='*20 + 'WARNING' + '='*20)
36
- warn(
37
- (
38
- "You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)"
39
- "in Hadamard Product representation!\n"
40
- "Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n"
41
- "If you just want to use high dim loha, please consider use lower lr."
42
- ),
43
- stacklevel=2,
44
- )
45
- print('='*20 + 'WARNING' + '='*20)
46
-
47
  network = LoRANetwork(
48
  text_encoder, unet,
49
  multiplier=multiplier,
 
5
  # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
 
7
  import math
 
8
  import os
9
  from typing import List
10
  import torch
 
27
  }[algo]
28
 
29
  print(f'Using rank adaptation algo: {algo}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  network = LoRANetwork(
31
  text_encoder, unet,
32
  multiplier=multiplier,
lycoris/loha.py CHANGED
@@ -36,12 +36,7 @@ class LohaModule(nn.Module):
36
  Hadamard product Implementation for Low Rank Adaptation
37
  """
38
 
39
- def __init__(
40
- self,
41
- lora_name,
42
- org_module: nn.Module,
43
- multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
44
- ):
45
  """ if alpha == 0 or None, alpha is rank (no scaling). """
46
  super().__init__()
47
  self.lora_name = lora_name
 
36
  Hadamard product Implementation for Low Rank Adaptation
37
  """
38
 
39
+ def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.):
 
 
 
 
 
40
  """ if alpha == 0 or None, alpha is rank (no scaling). """
41
  super().__init__()
42
  self.lora_name = lora_name
lycoris/utils.py CHANGED
@@ -28,13 +28,11 @@ def extract_conv(
28
  assert 1>=mode_param>=0
29
  min_s = torch.max(S)*mode_param
30
  lora_rank = torch.sum(S>min_s)
31
- elif mode=='quantile' or mode=='percentile':
32
  assert 1>=mode_param>=0
33
  s_cum = torch.cumsum(S, dim=0)
34
  min_cum_sum = mode_param * torch.sum(S)
35
  lora_rank = torch.sum(s_cum<min_cum_sum)
36
- else:
37
- raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
38
  lora_rank = max(1, lora_rank)
39
  lora_rank = min(out_ch, in_ch, lora_rank)
40
 
@@ -90,13 +88,11 @@ def extract_linear(
90
  assert 1>=mode_param>=0
91
  min_s = torch.max(S)*mode_param
92
  lora_rank = torch.sum(S>min_s)
93
- elif mode=='quantile' or mode=='percentile':
94
  assert 1>=mode_param>=0
95
  s_cum = torch.cumsum(S, dim=0)
96
  min_cum_sum = mode_param * torch.sum(S)
97
  lora_rank = torch.sum(s_cum<min_cum_sum)
98
- else:
99
- raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
100
  lora_rank = max(1, lora_rank)
101
  lora_rank = min(out_ch, in_ch, lora_rank)
102
 
@@ -263,69 +259,6 @@ def merge_locon(
263
  child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
264
  del delta
265
 
266
- merge(
267
- LORA_PREFIX_TEXT_ENCODER,
268
- base_model[0],
269
- TEXT_ENCODER_TARGET_REPLACE_MODULE
270
- )
271
- merge(
272
- LORA_PREFIX_UNET,
273
- base_model[2],
274
- UNET_TARGET_REPLACE_MODULE
275
- )
276
-
277
-
278
- def merge_loha(
279
- base_model,
280
- loha_state_dict: Dict[str, torch.TensorType],
281
- scale: float = 1.0,
282
- device = 'cpu'
283
- ):
284
- UNET_TARGET_REPLACE_MODULE = [
285
- "Transformer2DModel",
286
- "Attention",
287
- "ResnetBlock2D",
288
- "Downsample2D",
289
- "Upsample2D"
290
- ]
291
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
292
- LORA_PREFIX_UNET = 'lora_unet'
293
- LORA_PREFIX_TEXT_ENCODER = 'lora_te'
294
- def merge(
295
- prefix,
296
- root_module: torch.nn.Module,
297
- target_replace_modules
298
- ):
299
- temp = {}
300
-
301
- for name, module in tqdm(list(root_module.named_modules())):
302
- if module.__class__.__name__ in target_replace_modules:
303
- temp[name] = {}
304
- for child_name, child_module in module.named_modules():
305
- layer = child_module.__class__.__name__
306
- if layer not in {'Linear', 'Conv2d'}:
307
- continue
308
- lora_name = prefix + '.' + name + '.' + child_name
309
- lora_name = lora_name.replace('.', '_')
310
-
311
- w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
312
- w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
313
- w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
314
- w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
315
- alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
316
- dim = w1b.shape[0]
317
-
318
- delta = (w1a @ w1b) * (w2a @ w2b)
319
- delta = delta.reshape(child_module.weight.shape)
320
-
321
- if layer == 'Conv2d':
322
- child_module.weight.requires_grad_(False)
323
- child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
324
- elif layer == 'Linear':
325
- child_module.weight.requires_grad_(False)
326
- child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
327
- del delta
328
-
329
  merge(
330
  LORA_PREFIX_TEXT_ENCODER,
331
  base_model[0],
 
28
  assert 1>=mode_param>=0
29
  min_s = torch.max(S)*mode_param
30
  lora_rank = torch.sum(S>min_s)
31
+ elif mode=='percentile':
32
  assert 1>=mode_param>=0
33
  s_cum = torch.cumsum(S, dim=0)
34
  min_cum_sum = mode_param * torch.sum(S)
35
  lora_rank = torch.sum(s_cum<min_cum_sum)
 
 
36
  lora_rank = max(1, lora_rank)
37
  lora_rank = min(out_ch, in_ch, lora_rank)
38
 
 
88
  assert 1>=mode_param>=0
89
  min_s = torch.max(S)*mode_param
90
  lora_rank = torch.sum(S>min_s)
91
+ elif mode=='percentile':
92
  assert 1>=mode_param>=0
93
  s_cum = torch.cumsum(S, dim=0)
94
  min_cum_sum = mode_param * torch.sum(S)
95
  lora_rank = torch.sum(s_cum<min_cum_sum)
 
 
96
  lora_rank = max(1, lora_rank)
97
  lora_rank = min(out_ch, in_ch, lora_rank)
98
 
 
259
  child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
260
  del delta
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  merge(
263
  LORA_PREFIX_TEXT_ENCODER,
264
  base_model[0],
networks/check_lora_weights.py CHANGED
@@ -21,7 +21,7 @@ def main(file):
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
- print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
 
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
+ print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
networks/extract_lora_from_models.py CHANGED
@@ -45,13 +45,8 @@ def svd(args):
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
- if args.conv_dim is None:
49
- kwargs = {}
50
- else:
51
- kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
52
-
53
- lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
54
- lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
55
  assert len(lora_network_o.text_encoder_loras) == len(
56
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
57
 
@@ -90,28 +85,13 @@ def svd(args):
90
 
91
  # make LoRA with svd
92
  print("calculating by svd")
 
93
  lora_weights = {}
94
  with torch.no_grad():
95
  for lora_name, mat in tqdm(list(diffs.items())):
96
- # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
97
  conv2d = (len(mat.size()) == 4)
98
- kernel_size = None if not conv2d else mat.size()[2:4]
99
- conv2d_3x3 = conv2d and kernel_size != (1, 1)
100
-
101
- rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
102
- out_dim, in_dim = mat.size()[0:2]
103
-
104
- if args.device:
105
- mat = mat.to(args.device)
106
-
107
- # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
108
- rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
109
-
110
  if conv2d:
111
- if conv2d_3x3:
112
- mat = mat.flatten(start_dim=1)
113
- else:
114
- mat = mat.squeeze()
115
 
116
  U, S, Vh = torch.linalg.svd(mat)
117
 
@@ -128,27 +108,30 @@ def svd(args):
128
  U = U.clamp(low_val, hi_val)
129
  Vh = Vh.clamp(low_val, hi_val)
130
 
131
- if conv2d:
132
- U = U.reshape(out_dim, rank, 1, 1)
133
- Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
134
-
135
- U = U.to("cpu").contiguous()
136
- Vh = Vh.to("cpu").contiguous()
137
-
138
  lora_weights[lora_name] = (U, Vh)
139
 
140
  # make state dict for LoRA
141
- lora_sd = {}
142
- for lora_name, (up_weight, down_weight) in lora_weights.items():
143
- lora_sd[lora_name + '.lora_up.weight'] = up_weight
144
- lora_sd[lora_name + '.lora_down.weight'] = down_weight
145
- lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
146
 
147
- # load state dict to LoRA and save it
148
- lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
149
- lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
 
 
 
150
 
151
- info = lora_network_save.load_state_dict(lora_sd)
 
 
 
 
 
 
 
 
 
152
  print(f"Loading extracted LoRA weights: {info}")
153
 
154
  dir_name = os.path.dirname(args.save_to)
@@ -156,9 +139,9 @@ def svd(args):
156
  os.makedirs(dir_name, exist_ok=True)
157
 
158
  # minimum metadata
159
- metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
160
 
161
- lora_network_save.save_weights(args.save_to, save_dtype, metadata)
162
  print(f"LoRA weights are saved to: {args.save_to}")
163
 
164
 
@@ -175,8 +158,6 @@ if __name__ == '__main__':
175
  parser.add_argument("--save_to", type=str, default=None,
176
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
177
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
178
- parser.add_argument("--conv_dim", type=int, default=None,
179
- help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
180
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
181
 
182
  args = parser.parse_args()
 
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
+ lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
49
+ lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
 
 
 
 
 
50
  assert len(lora_network_o.text_encoder_loras) == len(
51
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
52
 
 
85
 
86
  # make LoRA with svd
87
  print("calculating by svd")
88
+ rank = args.dim
89
  lora_weights = {}
90
  with torch.no_grad():
91
  for lora_name, mat in tqdm(list(diffs.items())):
 
92
  conv2d = (len(mat.size()) == 4)
 
 
 
 
 
 
 
 
 
 
 
 
93
  if conv2d:
94
+ mat = mat.squeeze()
 
 
 
95
 
96
  U, S, Vh = torch.linalg.svd(mat)
97
 
 
108
  U = U.clamp(low_val, hi_val)
109
  Vh = Vh.clamp(low_val, hi_val)
110
 
 
 
 
 
 
 
 
111
  lora_weights[lora_name] = (U, Vh)
112
 
113
  # make state dict for LoRA
114
+ lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
115
+ lora_sd = lora_network_o.state_dict()
116
+ print(f"LoRA has {len(lora_sd)} weights.")
 
 
117
 
118
+ for key in list(lora_sd.keys()):
119
+ if "alpha" in key:
120
+ continue
121
+
122
+ lora_name = key.split('.')[0]
123
+ i = 0 if "lora_up" in key else 1
124
 
125
+ weights = lora_weights[lora_name][i]
126
+ # print(key, i, weights.size(), lora_sd[key].size())
127
+ if len(lora_sd[key].size()) == 4:
128
+ weights = weights.unsqueeze(2).unsqueeze(3)
129
+
130
+ assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
131
+ lora_sd[key] = weights
132
+
133
+ # load state dict to LoRA and save it
134
+ info = lora_network_o.load_state_dict(lora_sd)
135
  print(f"Loading extracted LoRA weights: {info}")
136
 
137
  dir_name = os.path.dirname(args.save_to)
 
139
  os.makedirs(dir_name, exist_ok=True)
140
 
141
  # minimum metadata
142
+ metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
143
 
144
+ lora_network_o.save_weights(args.save_to, save_dtype, metadata)
145
  print(f"LoRA weights are saved to: {args.save_to}")
146
 
147
 
 
158
  parser.add_argument("--save_to", type=str, default=None,
159
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
160
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
networks/lora.py CHANGED
@@ -6,7 +6,6 @@
6
  import math
7
  import os
8
  from typing import List
9
- import numpy as np
10
  import torch
11
 
12
  from library import train_util
@@ -21,34 +20,22 @@ class LoRAModule(torch.nn.Module):
21
  """ if alpha == 0 or None, alpha is rank (no scaling). """
22
  super().__init__()
23
  self.lora_name = lora_name
 
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
 
 
28
  else:
29
  in_dim = org_module.in_features
30
  out_dim = org_module.out_features
31
-
32
- # if limit_rank:
33
- # self.lora_dim = min(lora_dim, in_dim, out_dim)
34
- # if self.lora_dim != lora_dim:
35
- # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
36
- # else:
37
- self.lora_dim = lora_dim
38
-
39
- if org_module.__class__.__name__ == 'Conv2d':
40
- kernel_size = org_module.kernel_size
41
- stride = org_module.stride
42
- padding = org_module.padding
43
- self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
44
- self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
45
- else:
46
- self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
47
- self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
48
 
49
  if type(alpha) == torch.Tensor:
50
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
51
- alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
52
  self.scale = alpha / self.lora_dim
53
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
54
 
@@ -58,192 +45,69 @@ class LoRAModule(torch.nn.Module):
58
 
59
  self.multiplier = multiplier
60
  self.org_module = org_module # remove in applying
61
- self.region = None
62
- self.region_mask = None
63
 
64
  def apply_to(self):
65
  self.org_forward = self.org_module.forward
66
  self.org_module.forward = self.forward
67
  del self.org_module
68
 
69
- def set_region(self, region):
70
- self.region = region
71
- self.region_mask = None
72
-
73
  def forward(self, x):
74
- if self.region is None:
75
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
76
-
77
- # regional LoRA FIXME same as additional-network extension
78
- if x.size()[1] % 77 == 0:
79
- # print(f"LoRA for context: {self.lora_name}")
80
- self.region = None
81
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
82
-
83
- # calculate region mask first time
84
- if self.region_mask is None:
85
- if len(x.size()) == 4:
86
- h, w = x.size()[2:4]
87
- else:
88
- seq_len = x.size()[1]
89
- ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
90
- h = int(self.region.size()[0] / ratio + .5)
91
- w = seq_len // h
92
-
93
- r = self.region.to(x.device)
94
- if r.dtype == torch.bfloat16:
95
- r = r.to(torch.float)
96
- r = r.unsqueeze(0).unsqueeze(1)
97
- # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
98
- r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
99
- r = r.to(x.dtype)
100
-
101
- if len(x.size()) == 3:
102
- r = torch.reshape(r, (1, x.size()[1], -1))
103
-
104
- self.region_mask = r
105
-
106
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
107
 
108
 
109
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
110
  if network_dim is None:
111
  network_dim = 4 # default
112
-
113
- # extract dim/alpha for conv2d, and block dim
114
- conv_dim = kwargs.get('conv_dim', None)
115
- conv_alpha = kwargs.get('conv_alpha', None)
116
- if conv_dim is not None:
117
- conv_dim = int(conv_dim)
118
- if conv_alpha is None:
119
- conv_alpha = 1.0
120
- else:
121
- conv_alpha = float(conv_alpha)
122
-
123
- """
124
- block_dims = kwargs.get("block_dims")
125
- block_alphas = None
126
-
127
- if block_dims is not None:
128
- block_dims = [int(d) for d in block_dims.split(',')]
129
- assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
130
- block_alphas = kwargs.get("block_alphas")
131
- if block_alphas is None:
132
- block_alphas = [1] * len(block_dims)
133
- else:
134
- block_alphas = [int(a) for a in block_alphas(',')]
135
- assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
136
-
137
- conv_block_dims = kwargs.get("conv_block_dims")
138
- conv_block_alphas = None
139
-
140
- if conv_block_dims is not None:
141
- conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
142
- assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
143
- conv_block_alphas = kwargs.get("conv_block_alphas")
144
- if conv_block_alphas is None:
145
- conv_block_alphas = [1] * len(conv_block_dims)
146
- else:
147
- conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
148
- assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
149
- """
150
-
151
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
152
- alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
153
  return network
154
 
155
 
156
- def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
157
- if weights_sd is None:
158
- if os.path.splitext(file)[1] == '.safetensors':
159
- from safetensors.torch import load_file, safe_open
160
- weights_sd = load_file(file)
161
- else:
162
- weights_sd = torch.load(file, map_location='cpu')
163
 
164
- # get dim/alpha mapping
165
- modules_dim = {}
166
- modules_alpha = {}
167
  for key, value in weights_sd.items():
168
- if '.' not in key:
169
- continue
170
-
171
- lora_name = key.split('.')[0]
172
- if 'alpha' in key:
173
- modules_alpha[lora_name] = value
174
- elif 'lora_down' in key:
175
- dim = value.size()[0]
176
- modules_dim[lora_name] = dim
177
- # print(lora_name, value.size(), dim)
178
-
179
- # support old LoRA without alpha
180
- for key in modules_dim.keys():
181
- if key not in modules_alpha:
182
- modules_alpha = modules_dim[key]
183
-
184
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
185
  network.weights_sd = weights_sd
186
  return network
187
 
188
 
189
  class LoRANetwork(torch.nn.Module):
190
- # is it possible to apply conv_in and conv_out?
191
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
192
- UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
193
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
194
  LORA_PREFIX_UNET = 'lora_unet'
195
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
196
 
197
- def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
198
  super().__init__()
199
  self.multiplier = multiplier
200
-
201
  self.lora_dim = lora_dim
202
  self.alpha = alpha
203
- self.conv_lora_dim = conv_lora_dim
204
- self.conv_alpha = conv_alpha
205
-
206
- if modules_dim is not None:
207
- print(f"create LoRA network from weights")
208
- else:
209
- print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
210
-
211
- self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
212
- if self.apply_to_conv2d_3x3:
213
- if self.conv_alpha is None:
214
- self.conv_alpha = self.alpha
215
- print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
216
 
217
  # create module instances
218
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
219
  loras = []
220
  for name, module in root_module.named_modules():
221
  if module.__class__.__name__ in target_replace_modules:
222
- # TODO get block index here
223
  for child_name, child_module in module.named_modules():
224
- is_linear = child_module.__class__.__name__ == "Linear"
225
- is_conv2d = child_module.__class__.__name__ == "Conv2d"
226
- is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
227
- if is_linear or is_conv2d:
228
  lora_name = prefix + '.' + name + '.' + child_name
229
  lora_name = lora_name.replace('.', '_')
230
-
231
- if modules_dim is not None:
232
- if lora_name not in modules_dim:
233
- continue # no LoRA module in this weights file
234
- dim = modules_dim[lora_name]
235
- alpha = modules_alpha[lora_name]
236
- else:
237
- if is_linear or is_conv2d_1x1:
238
- dim = self.lora_dim
239
- alpha = self.alpha
240
- elif self.apply_to_conv2d_3x3:
241
- dim = self.conv_lora_dim
242
- alpha = self.conv_alpha
243
- else:
244
- continue
245
-
246
- lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
247
  loras.append(lora)
248
  return loras
249
 
@@ -251,12 +115,7 @@ class LoRANetwork(torch.nn.Module):
251
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
252
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
253
 
254
- # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
255
- target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
256
- if modules_dim is not None or self.conv_lora_dim is not None:
257
- target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
258
-
259
- self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
260
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
261
 
262
  self.weights_sd = None
@@ -267,11 +126,6 @@ class LoRANetwork(torch.nn.Module):
267
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
268
  names.add(lora.lora_name)
269
 
270
- def set_multiplier(self, multiplier):
271
- self.multiplier = multiplier
272
- for lora in self.text_encoder_loras + self.unet_loras:
273
- lora.multiplier = self.multiplier
274
-
275
  def load_weights(self, file):
276
  if os.path.splitext(file)[1] == '.safetensors':
277
  from safetensors.torch import load_file, safe_open
@@ -381,18 +235,3 @@ class LoRANetwork(torch.nn.Module):
381
  save_file(state_dict, file, metadata)
382
  else:
383
  torch.save(state_dict, file)
384
-
385
- @ staticmethod
386
- def set_regions(networks, image):
387
- image = image.astype(np.float32) / 255.0
388
- for i, network in enumerate(networks[:3]):
389
- # NOTE: consider averaging overwrapping area
390
- region = image[:, :, i]
391
- if region.max() == 0:
392
- continue
393
- region = torch.tensor(region)
394
- network.set_region(region)
395
-
396
- def set_region(self, region):
397
- for lora in self.unet_loras:
398
- lora.set_region(region)
 
6
  import math
7
  import os
8
  from typing import List
 
9
  import torch
10
 
11
  from library import train_util
 
20
  """ if alpha == 0 or None, alpha is rank (no scaling). """
21
  super().__init__()
22
  self.lora_name = lora_name
23
+ self.lora_dim = lora_dim
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
28
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
29
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
30
  else:
31
  in_dim = org_module.in_features
32
  out_dim = org_module.out_features
33
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
34
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if type(alpha) == torch.Tensor:
37
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
38
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
39
  self.scale = alpha / self.lora_dim
40
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
41
 
 
45
 
46
  self.multiplier = multiplier
47
  self.org_module = org_module # remove in applying
 
 
48
 
49
  def apply_to(self):
50
  self.org_forward = self.org_module.forward
51
  self.org_module.forward = self.forward
52
  del self.org_module
53
 
 
 
 
 
54
  def forward(self, x):
55
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
59
  if network_dim is None:
60
  network_dim = 4 # default
61
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  return network
63
 
64
 
65
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
66
+ if os.path.splitext(file)[1] == '.safetensors':
67
+ from safetensors.torch import load_file, safe_open
68
+ weights_sd = load_file(file)
69
+ else:
70
+ weights_sd = torch.load(file, map_location='cpu')
 
71
 
72
+ # get dim (rank)
73
+ network_alpha = None
74
+ network_dim = None
75
  for key, value in weights_sd.items():
76
+ if network_alpha is None and 'alpha' in key:
77
+ network_alpha = value
78
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
79
+ network_dim = value.size()[0]
80
+
81
+ if network_alpha is None:
82
+ network_alpha = network_dim
83
+
84
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
 
 
 
 
 
 
 
 
85
  network.weights_sd = weights_sd
86
  return network
87
 
88
 
89
  class LoRANetwork(torch.nn.Module):
 
90
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
 
91
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
  LORA_PREFIX_UNET = 'lora_unet'
93
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
 
95
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
96
  super().__init__()
97
  self.multiplier = multiplier
 
98
  self.lora_dim = lora_dim
99
  self.alpha = alpha
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # create module instances
102
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
103
  loras = []
104
  for name, module in root_module.named_modules():
105
  if module.__class__.__name__ in target_replace_modules:
 
106
  for child_name, child_module in module.named_modules():
107
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
 
 
 
108
  lora_name = prefix + '.' + name + '.' + child_name
109
  lora_name = lora_name.replace('.', '_')
110
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  loras.append(lora)
112
  return loras
113
 
 
115
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
 
118
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
 
 
 
 
 
119
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
 
121
  self.weights_sd = None
 
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
 
 
 
 
 
129
  def load_weights(self, file):
130
  if os.path.splitext(file)[1] == '.safetensors':
131
  from safetensors.torch import load_file, safe_open
 
235
  save_file(state_dict, file, metadata)
236
  else:
237
  torch.save(state_dict, file)
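For reference, the simplified LoRAModule.forward above is the standard low-rank update org_forward(x) + lora_up(lora_down(x)) * multiplier * (alpha / dim); a minimal sketch of that arithmetic with toy shapes and a zero-initialised up projection (illustrative only, not part of this commit):

import torch

dim, alpha, multiplier = 4, 1.0, 1.0
base = torch.nn.Linear(768, 768, bias=False)
lora_down = torch.nn.Linear(768, dim, bias=False)
lora_up = torch.nn.Linear(dim, 768, bias=False)
torch.nn.init.zeros_(lora_up.weight)   # zero up-projection makes the LoRA start as a no-op

x = torch.randn(1, 768)
y = base(x) + lora_up(lora_down(x)) * multiplier * (alpha / dim)
assert torch.allclose(y, base(x))      # holds because lora_up is all zeros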
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
networks/merge_lora.py CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
- if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
@@ -80,19 +80,13 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
80
 
81
  # W <- W + U * D
82
  weight = module.weight
83
- # print(module_name, down_weight.size(), up_weight.size())
84
  if len(weight.size()) == 2:
85
  # linear
86
  weight = weight + ratio * (up_weight @ down_weight) * scale
87
- elif down_weight.size()[2:4] == (1, 1):
88
- # conv2d 1x1
89
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
90
  ).unsqueeze(2).unsqueeze(3) * scale
91
- else:
92
- # conv2d 3x3
93
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
94
- # print(conved.size(), weight.size(), module.stride, module.padding)
95
- weight = weight + ratio * conved * scale
96
 
97
  module.weight = torch.nn.Parameter(weight)
98
 
@@ -129,7 +123,7 @@ def merge_lora_models(models, ratios, merge_dtype):
129
  alphas[lora_module_name] = alpha
130
  if lora_module_name not in base_alphas:
131
  base_alphas[lora_module_name] = alpha
132
-
133
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
134
 
135
  # merge
@@ -151,7 +145,7 @@ def merge_lora_models(models, ratios, merge_dtype):
151
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
152
  else:
153
  merged_sd[key] = lora_sd[key] * scale
154
-
155
  # set alpha to sd
156
  for lora_module_name, alpha in base_alphas.items():
157
  key = lora_module_name + ".alpha"
 
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
 
80
 
81
  # W <- W + U * D
82
  weight = module.weight
 
83
  if len(weight.size()) == 2:
84
  # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ else:
87
+ # conv2d
88
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
89
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
 
 
90
 
91
  module.weight = torch.nn.Parameter(weight)
92
 
 
123
  alphas[lora_module_name] = alpha
124
  if lora_module_name not in base_alphas:
125
  base_alphas[lora_module_name] = alpha
126
+
127
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
128
 
129
  # merge
 
145
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
146
  else:
147
  merged_sd[key] = lora_sd[key] * scale
148
+
149
  # set alpha to sd
150
  for lora_module_name, alpha in base_alphas.items():
151
  key = lora_module_name + ".alpha"
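For reference, the linear branch of the merge above ("W <- W + U * D") folds each low-rank pair straight into the base weight; a minimal sketch with toy shapes (illustrative only, not part of this commit):

import torch

out_dim, in_dim, rank = 8, 8, 2
ratio, alpha = 1.0, 2.0
weight = torch.randn(out_dim, in_dim)    # original module weight
up_weight = torch.randn(out_dim, rank)   # lora_up.weight
down_weight = torch.randn(rank, in_dim)  # lora_down.weight

scale = alpha / rank
weight = weight + ratio * (up_weight @ down_weight) * scale   # same update as the linear case above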
networks/resize_lora.py CHANGED
@@ -1,15 +1,14 @@
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
- # Thanks to cloneofsimo
4
 
5
  import argparse
 
6
  import torch
7
  from safetensors.torch import load_file, save_file, safe_open
8
  from tqdm import tqdm
9
  from library import train_util, model_util
10
- import numpy as np
11
 
12
- MIN_SV = 1e-6
13
 
14
  def load_state_dict(file_name, dtype):
15
  if model_util.is_safetensors(file_name):
@@ -39,149 +38,12 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
39
  torch.save(model, file_name)
40
 
41
 
42
- def index_sv_cumulative(S, target):
43
- original_sum = float(torch.sum(S))
44
- cumulative_sums = torch.cumsum(S, dim=0)/original_sum
45
- index = int(torch.searchsorted(cumulative_sums, target)) + 1
46
- if index >= len(S):
47
- index = len(S) - 1
48
-
49
- return index
50
-
51
-
52
- def index_sv_fro(S, target):
53
- S_squared = S.pow(2)
54
- s_fro_sq = float(torch.sum(S_squared))
55
- sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
56
- index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
57
- if index >= len(S):
58
- index = len(S) - 1
59
-
60
- return index
61
-
62
-
63
- # Modified from Kohaku-blueleaf's extract/merge functions
64
- def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
65
- out_size, in_size, kernel_size, _ = weight.size()
66
- U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
67
-
68
- param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
69
- lora_rank = param_dict["new_rank"]
70
-
71
- U = U[:, :lora_rank]
72
- S = S[:lora_rank]
73
- U = U @ torch.diag(S)
74
- Vh = Vh[:lora_rank, :]
75
-
76
- param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
77
- param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
78
- del U, S, Vh, weight
79
- return param_dict
80
-
81
-
82
- def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
83
- out_size, in_size = weight.size()
84
-
85
- U, S, Vh = torch.linalg.svd(weight.to(device))
86
-
87
- param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
88
- lora_rank = param_dict["new_rank"]
89
-
90
- U = U[:, :lora_rank]
91
- S = S[:lora_rank]
92
- U = U @ torch.diag(S)
93
- Vh = Vh[:lora_rank, :]
94
-
95
- param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
96
- param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
97
- del U, S, Vh, weight
98
- return param_dict
99
-
100
-
101
- def merge_conv(lora_down, lora_up, device):
102
- in_rank, in_size, kernel_size, k_ = lora_down.shape
103
- out_size, out_rank, _, _ = lora_up.shape
104
- assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
105
-
106
- lora_down = lora_down.to(device)
107
- lora_up = lora_up.to(device)
108
-
109
- merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
110
- weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
111
- del lora_up, lora_down
112
- return weight
113
-
114
-
115
- def merge_linear(lora_down, lora_up, device):
116
- in_rank, in_size = lora_down.shape
117
- out_size, out_rank = lora_up.shape
118
- assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
119
-
120
- lora_down = lora_down.to(device)
121
- lora_up = lora_up.to(device)
122
-
123
- weight = lora_up @ lora_down
124
- del lora_up, lora_down
125
- return weight
126
-
127
-
128
- def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
129
- param_dict = {}
130
-
131
- if dynamic_method=="sv_ratio":
132
- # Calculate new dim and alpha based off ratio
133
- max_sv = S[0]
134
- min_sv = max_sv/dynamic_param
135
- new_rank = max(torch.sum(S > min_sv).item(),1)
136
- new_alpha = float(scale*new_rank)
137
-
138
- elif dynamic_method=="sv_cumulative":
139
- # Calculate new dim and alpha based off cumulative sum
140
- new_rank = index_sv_cumulative(S, dynamic_param)
141
- new_rank = max(new_rank, 1)
142
- new_alpha = float(scale*new_rank)
143
-
144
- elif dynamic_method=="sv_fro":
145
- # Calculate new dim and alpha based off sqrt sum of squares
146
- new_rank = index_sv_fro(S, dynamic_param)
147
- new_rank = min(max(new_rank, 1), len(S)-1)
148
- new_alpha = float(scale*new_rank)
149
- else:
150
- new_rank = rank
151
- new_alpha = float(scale*new_rank)
152
-
153
-
154
- if S[0] <= MIN_SV: # Zero matrix, set dim to 1
155
- new_rank = 1
156
- new_alpha = float(scale*new_rank)
157
- elif new_rank > rank: # cap max rank at rank
158
- new_rank = rank
159
- new_alpha = float(scale*new_rank)
160
-
161
-
162
- # Calculate resize info
163
- s_sum = torch.sum(torch.abs(S))
164
- s_rank = torch.sum(torch.abs(S[:new_rank]))
165
-
166
- S_squared = S.pow(2)
167
- s_fro = torch.sqrt(torch.sum(S_squared))
168
- s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
169
- fro_percent = float(s_red_fro/s_fro)
170
-
171
- param_dict["new_rank"] = new_rank
172
- param_dict["new_alpha"] = new_alpha
173
- param_dict["sum_retained"] = (s_rank)/s_sum
174
- param_dict["fro_retained"] = fro_percent
175
- param_dict["max_ratio"] = S[0]/S[new_rank]
176
-
177
- return param_dict
178
-
179
-
180
- def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
181
  network_alpha = None
182
  network_dim = None
183
  verbose_str = "\n"
184
- fro_list = []
 
185
 
186
  # Extract loaded lora dim and alpha
187
  for key, value in lora_sd.items():
@@ -195,9 +57,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
195
  network_alpha = network_dim
196
 
197
  scale = network_alpha/network_dim
 
198
 
199
- if dynamic_method:
200
- print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
201
 
202
  lora_down_weight = None
203
  lora_up_weight = None
@@ -206,6 +68,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
206
  block_down_name = None
207
  block_up_name = None
208
 
 
209
  with torch.no_grad():
210
  for key, value in tqdm(lora_sd.items()):
211
  if 'lora_down' in key:
@@ -222,43 +85,57 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
222
  conv2d = (len(lora_down_weight.size()) == 4)
223
 
224
  if conv2d:
225
- full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
226
- param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
227
- else:
228
- full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
229
- param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
 
 
 
 
 
230
 
231
  if verbose:
232
- max_ratio = param_dict['max_ratio']
233
- sum_retained = param_dict['sum_retained']
234
- fro_retained = param_dict['fro_retained']
235
- if not np.isnan(fro_retained):
236
- fro_list.append(float(fro_retained))
237
 
238
- verbose_str+=f"{block_down_name:75} | "
239
- verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
240
 
241
- if verbose and dynamic_method:
242
- verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
243
- else:
244
- verbose_str+=f"\n"
245
 
246
- new_alpha = param_dict['new_alpha']
247
- o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
248
- o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
249
- o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  block_down_name = None
252
  block_up_name = None
253
  lora_down_weight = None
254
  lora_up_weight = None
255
  weights_loaded = False
256
- del param_dict
257
 
258
  if verbose:
259
  print(verbose_str)
260
-
261
- print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
262
  print("resizing complete")
263
  return o_lora_sd, network_dim, new_alpha
264
 
@@ -274,9 +151,6 @@ def resize(args):
274
  return torch.bfloat16
275
  return None
276
 
277
- if args.dynamic_method and not args.dynamic_param:
278
- raise Exception("If using dynamic_method, then dynamic_param is required")
279
-
280
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
281
  save_dtype = str_to_dtype(args.save_precision)
282
  if save_dtype is None:
@@ -285,23 +159,17 @@ def resize(args):
285
  print("loading Model...")
286
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
287
 
288
- print("Resizing Lora...")
289
- state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
290
 
291
  # update metadata
292
  if metadata is None:
293
  metadata = {}
294
 
295
  comment = metadata.get("ss_training_comment", "")
296
-
297
- if not args.dynamic_method:
298
- metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
299
- metadata["ss_network_dim"] = str(args.new_rank)
300
- metadata["ss_network_alpha"] = str(new_alpha)
301
- else:
302
- metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
303
- metadata["ss_network_dim"] = 'Dynamic'
304
- metadata["ss_network_alpha"] = 'Dynamic'
305
 
306
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
307
  metadata["sshs_model_hash"] = model_hash
@@ -325,11 +193,6 @@ if __name__ == '__main__':
325
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
326
  parser.add_argument("--verbose", action="store_true",
327
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
328
- parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
329
- help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
330
- parser.add_argument("--dynamic_param", type=float, default=None,
331
- help="Specify target for dynamic reduction")
332
-
333
 
334
  args = parser.parse_args()
335
  resize(args)
 
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo and kohya
4
 
5
  import argparse
6
+ import os
7
  import torch
8
  from safetensors.torch import load_file, save_file, safe_open
9
  from tqdm import tqdm
10
  from library import train_util, model_util
 
11
 
 
12
 
13
  def load_state_dict(file_name, dtype):
14
  if model_util.is_safetensors(file_name):
 
38
  torch.save(model, file_name)
39
 
40
 
41
+ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  network_alpha = None
43
  network_dim = None
44
  verbose_str = "\n"
45
+
46
+ CLAMP_QUANTILE = 0.99
47
 
48
  # Extract loaded lora dim and alpha
49
  for key, value in lora_sd.items():
 
57
  network_alpha = network_dim
58
 
59
  scale = network_alpha/network_dim
60
+ new_alpha = float(scale*new_rank) # calculate new alpha from scale
61
 
62
+ print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
 
63
 
64
  lora_down_weight = None
65
  lora_up_weight = None
 
68
  block_down_name = None
69
  block_up_name = None
70
 
71
+ print("resizing lora...")
72
  with torch.no_grad():
73
  for key, value in tqdm(lora_sd.items()):
74
  if 'lora_down' in key:
 
85
  conv2d = (len(lora_down_weight.size()) == 4)
86
 
87
  if conv2d:
88
+ lora_down_weight = lora_down_weight.squeeze()
89
+ lora_up_weight = lora_up_weight.squeeze()
90
+
91
+ if device:
92
+ org_device = lora_up_weight.device
93
+ lora_up_weight = lora_up_weight.to(args.device)
94
+ lora_down_weight = lora_down_weight.to(args.device)
95
+
96
+ full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
97
+
98
+ U, S, Vh = torch.linalg.svd(full_weight_matrix)
99
 
100
  if verbose:
101
+ s_sum = torch.sum(torch.abs(S))
102
+ s_rank = torch.sum(torch.abs(S[:new_rank]))
103
+ verbose_str+=f"{block_down_name:76} | "
104
+ verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
 
105
 
106
+ U = U[:, :new_rank]
107
+ S = S[:new_rank]
108
+ U = U @ torch.diag(S)
109
 
110
+ Vh = Vh[:new_rank, :]
 
 
 
111
 
112
+ dist = torch.cat([U.flatten(), Vh.flatten()])
113
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
114
+ low_val = -hi_val
115
+
116
+ U = U.clamp(low_val, hi_val)
117
+ Vh = Vh.clamp(low_val, hi_val)
118
+
119
+ if conv2d:
120
+ U = U.unsqueeze(2).unsqueeze(3)
121
+ Vh = Vh.unsqueeze(2).unsqueeze(3)
122
+
123
+ if device:
124
+ U = U.to(org_device)
125
+ Vh = Vh.to(org_device)
126
+
127
+ o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
128
+ o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
129
+ o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
130
 
131
  block_down_name = None
132
  block_up_name = None
133
  lora_down_weight = None
134
  lora_up_weight = None
135
  weights_loaded = False
 
136
 
137
  if verbose:
138
  print(verbose_str)
 
 
139
  print("resizing complete")
140
  return o_lora_sd, network_dim, new_alpha
141
 
 
151
  return torch.bfloat16
152
  return None
153
 
 
 
 
154
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
155
  save_dtype = str_to_dtype(args.save_precision)
156
  if save_dtype is None:
 
159
  print("loading Model...")
160
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
161
 
162
+ print("resizing rank...")
163
+ state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
164
 
165
  # update metadata
166
  if metadata is None:
167
  metadata = {}
168
 
169
  comment = metadata.get("ss_training_comment", "")
170
+ metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
171
+ metadata["ss_network_dim"] = str(args.new_rank)
172
+ metadata["ss_network_alpha"] = str(new_alpha)
 
 
 
 
 
 
173
 
174
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
175
  metadata["sshs_model_hash"] = model_hash
 
193
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
194
  parser.add_argument("--verbose", action="store_true",
195
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
 
 
 
 
 
196
 
197
  args = parser.parse_args()
198
  resize(args)
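For reference, the resize path restored above is a plain truncated SVD of the merged up @ down product followed by a quantile clamp; a minimal sketch of that idea with toy shapes (illustrative only, not part of this commit):

import torch

new_rank, clamp_quantile = 4, 0.99
up = torch.randn(320, 8)     # stands in for lora_up.weight
down = torch.randn(8, 320)   # stands in for lora_down.weight

U, S, Vh = torch.linalg.svd(up @ down)
U = U[:, :new_rank] @ torch.diag(S[:new_rank])   # fold the kept singular values into the up side
Vh = Vh[:new_rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
new_up, new_down = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)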
networks/svd_merge_lora.py CHANGED
@@ -23,20 +23,19 @@ def load_state_dict(file_name, dtype):
23
  return sd
24
 
25
 
26
- def save_to_file(file_name, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
- save_file(state_dict, file_name)
34
  else:
35
- torch.save(state_dict, file_name)
36
 
37
 
38
- def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
39
- print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
40
  merged_sd = {}
41
  for model, ratio in zip(models, ratios):
42
  print(f"loading: {model}")
@@ -59,12 +58,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
59
  in_dim = down_weight.size()[1]
60
  out_dim = up_weight.size()[0]
61
  conv2d = len(down_weight.size()) == 4
62
- kernel_size = None if not conv2d else down_weight.size()[2:4]
63
- # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
64
 
65
  # make original weight if not exist
66
  if lora_module_name not in merged_sd:
67
- weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
68
  if device:
69
  weight = weight.to(device)
70
  else:
@@ -77,18 +75,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
77
 
78
  # W <- W + U * D
79
  scale = (alpha / network_dim)
80
-
81
- if device: # and isinstance(scale, torch.Tensor):
82
- scale = scale.to(device)
83
-
84
  if not conv2d: # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
- elif kernel_size == (1, 1):
87
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
88
  ).unsqueeze(2).unsqueeze(3) * scale
89
- else:
90
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
91
- weight = weight + ratio * conved * scale
92
 
93
  merged_sd[lora_module_name] = weight
94
 
@@ -98,26 +89,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
98
  with torch.no_grad():
99
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
100
  conv2d = (len(mat.size()) == 4)
101
- kernel_size = None if not conv2d else mat.size()[2:4]
102
- conv2d_3x3 = conv2d and kernel_size != (1, 1)
103
- out_dim, in_dim = mat.size()[0:2]
104
-
105
  if conv2d:
106
- if conv2d_3x3:
107
- mat = mat.flatten(start_dim=1)
108
- else:
109
- mat = mat.squeeze()
110
-
111
- module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
112
- module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
113
 
114
  U, S, Vh = torch.linalg.svd(mat)
115
 
116
- U = U[:, :module_new_rank]
117
- S = S[:module_new_rank]
118
  U = U @ torch.diag(S)
119
 
120
- Vh = Vh[:module_new_rank, :]
121
 
122
  dist = torch.cat([U.flatten(), Vh.flatten()])
123
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
@@ -126,16 +107,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
126
  U = U.clamp(low_val, hi_val)
127
  Vh = Vh.clamp(low_val, hi_val)
128
 
129
- if conv2d:
130
- U = U.reshape(out_dim, module_new_rank, 1, 1)
131
- Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
132
-
133
  up_weight = U
134
  down_weight = Vh
135
 
 
 
 
 
136
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
137
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
138
- merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
139
 
140
  return merged_lora_sd
141
 
@@ -157,11 +138,10 @@ def merge(args):
157
  if save_dtype is None:
158
  save_dtype = merge_dtype
159
 
160
- new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
161
- state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
162
 
163
  print(f"saving model to: {args.save_to}")
164
- save_to_file(args.save_to, state_dict, save_dtype)
165
 
166
 
167
  if __name__ == '__main__':
@@ -178,8 +158,6 @@ if __name__ == '__main__':
178
  help="ratios for each model / それぞれのLoRAモデルの比率")
179
  parser.add_argument("--new_rank", type=int, default=4,
180
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
181
- parser.add_argument("--new_conv_rank", type=int, default=None,
182
- help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
183
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
184
 
185
  args = parser.parse_args()
 
23
  return sd
24
 
25
 
26
+ def save_to_file(file_name, model, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
+ save_file(model, file_name)
34
  else:
35
+ torch.save(model, file_name)
36
 
37
 
38
+ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
 
39
  merged_sd = {}
40
  for model, ratio in zip(models, ratios):
41
  print(f"loading: {model}")
 
58
  in_dim = down_weight.size()[1]
59
  out_dim = up_weight.size()[0]
60
  conv2d = len(down_weight.size()) == 4
61
+ print(lora_module_name, network_dim, alpha, in_dim, out_dim)
 
62
 
63
  # make original weight if not exist
64
  if lora_module_name not in merged_sd:
65
+ weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
66
  if device:
67
  weight = weight.to(device)
68
  else:
 
75
 
76
  # W <- W + U * D
77
  scale = (alpha / network_dim)
 
 
 
 
78
  if not conv2d: # linear
79
  weight = weight + ratio * (up_weight @ down_weight) * scale
80
+ else:
81
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
82
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
83
 
84
  merged_sd[lora_module_name] = weight
85
 
 
89
  with torch.no_grad():
90
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
91
  conv2d = (len(mat.size()) == 4)
 
 
 
 
92
  if conv2d:
93
+ mat = mat.squeeze()
 
 
 
 
 
 
94
 
95
  U, S, Vh = torch.linalg.svd(mat)
96
 
97
+ U = U[:, :new_rank]
98
+ S = S[:new_rank]
99
  U = U @ torch.diag(S)
100
 
101
+ Vh = Vh[:new_rank, :]
102
 
103
  dist = torch.cat([U.flatten(), Vh.flatten()])
104
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
 
107
  U = U.clamp(low_val, hi_val)
108
  Vh = Vh.clamp(low_val, hi_val)
109
 
 
 
 
 
110
  up_weight = U
111
  down_weight = Vh
112
 
113
+ if conv2d:
114
+ up_weight = up_weight.unsqueeze(2).unsqueeze(3)
115
+ down_weight = down_weight.unsqueeze(2).unsqueeze(3)
116
+
117
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
118
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
119
+ merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
120
 
121
  return merged_lora_sd
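The loop above re-factorizes each merged weight back into a low-rank pair: the matrix is decomposed with torch.linalg.svd, U / S / Vh are truncated to new_rank, the singular values are folded into U, and both factors are clamped at the CLAMP_QUANTILE quantile before being stored as lora_up / lora_down. A compact standalone sketch of that step, with an illustrative function name and quantile default (not part of this script):

    import torch

    def svd_reduce(mat: torch.Tensor, new_rank: int, clamp_quantile: float = 0.99):
        # keep only the top new_rank singular components of the merged weight
        U, S, Vh = torch.linalg.svd(mat)
        U = U[:, :new_rank] @ torch.diag(S[:new_rank])  # fold singular values into U
        Vh = Vh[:new_rank, :]
        # clamp outliers so the re-factorized pair stays in a reasonable range
        hi = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), clamp_quantile)
        return U.clamp(-hi, hi), Vh.clamp(-hi, hi)  # lora_up weight, lora_down weight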
122
 
 
138
  if save_dtype is None:
139
  save_dtype = merge_dtype
140
 
141
+ state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
 
142
 
143
  print(f"saving model to: {args.save_to}")
144
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype)
145
 
146
 
147
  if __name__ == '__main__':
 
158
  help="ratios for each model / それぞれのLoRAモデルの比率")
159
  parser.add_argument("--new_rank", type=int, default=4,
160
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
requirements.txt CHANGED
@@ -12,8 +12,6 @@ safetensors==0.2.6
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
15
- toml==0.10.2
16
- voluptuous==0.13.1
17
  # for BLIP captioning
18
  requests==2.28.2
19
  timm==0.6.12
@@ -23,4 +21,5 @@ fairscale==0.4.13
23
  tensorflow==2.10.1
24
  huggingface-hub==0.12.0
25
  # for kohya_ss library
 
26
  .
 
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
 
 
15
  # for BLIP captioning
16
  requests==2.28.2
17
  timm==0.6.12
 
21
  tensorflow==2.10.1
22
  huggingface-hub==0.12.0
23
  # for kohya_ss library
24
+ #locon.locon_kohya
25
  .
requirements_startup.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ transformers==4.26.0
3
+ ftfy==6.1.1
4
+ albumentations==1.3.0
5
+ opencv-python==4.7.0.68
6
+ einops==0.6.0
7
+ diffusers[torch]==0.10.2
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.35.0
10
+ tensorboard==2.10.1
11
+ safetensors==0.2.6
12
+ gradio==3.18.0
13
+ altair==4.2.2
14
+ easygui==0.98.3
15
+ # for BLIP captioning
16
+ requests==2.28.2
17
+ timm==0.4.12
18
+ fairscale==0.4.4
19
+ # for WD14 captioning
20
+ tensorflow==2.10.1
21
+ huggingface-hub==0.12.0
22
+ # for kohya_ss library
23
+ .
train_db.py CHANGED
@@ -15,11 +15,7 @@ import diffusers
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
- import library.config_util as config_util
19
- from library.config_util import (
20
- ConfigSanitizer,
21
- BlueprintGenerator,
22
- )
23
 
24
 
25
  def collate_fn(examples):
@@ -37,33 +33,24 @@ def train(args):
37
 
38
  tokenizer = train_util.load_tokenizer(args)
39
 
40
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
41
- if args.dataset_config is not None:
42
- print(f"Load dataset config from {args.dataset_config}")
43
- user_config = config_util.load_user_config(args.dataset_config)
44
- ignored = ["train_data_dir", "reg_data_dir"]
45
- if any(getattr(args, attr) is not None for attr in ignored):
46
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
47
- else:
48
- user_config = {
49
- "datasets": [{
50
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
51
- }]
52
- }
53
-
54
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.no_token_padding:
58
- train_dataset_group.disable_token_padding()
 
 
 
 
 
59
 
60
  if args.debug_dataset:
61
- train_util.debug_dataset(train_dataset_group)
62
  return
63
 
64
- if cache_latents:
65
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
-
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
 
@@ -104,7 +91,7 @@ def train(args):
104
  vae.requires_grad_(False)
105
  vae.eval()
106
  with torch.no_grad():
107
- train_dataset_group.cache_latents(vae)
108
  vae.to("cpu")
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
@@ -128,18 +115,38 @@ def train(args):
128
 
129
  # 学習に必要なクラスを準備する
130
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if train_text_encoder:
132
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
133
  else:
134
  trainable_params = unet.parameters()
135
 
136
- _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
137
 
138
  # dataloaderを準備する
139
  # DataLoaderのプロセス数:0はメインプロセスになる
140
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
141
  train_dataloader = torch.utils.data.DataLoader(
142
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
143
 
144
  # 学習ステップ数を計算する
145
  if args.max_train_epochs is not None:
@@ -149,10 +156,9 @@ def train(args):
149
  if args.stop_text_encoder_training is None:
150
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
151
 
152
- # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
153
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
154
- num_training_steps=args.max_train_steps,
155
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
156
 
157
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
158
  if args.full_fp16:
@@ -189,8 +195,8 @@ def train(args):
189
  # 学習する
190
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
191
  print("running training / 学習開始")
192
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
193
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
194
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
195
  print(f" num epochs / epoch数: {num_train_epochs}")
196
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -211,7 +217,7 @@ def train(args):
211
  loss_total = 0.0
212
  for epoch in range(num_train_epochs):
213
  print(f"epoch {epoch+1}/{num_train_epochs}")
214
- train_dataset_group.set_current_epoch(epoch + 1)
215
 
216
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
217
  unet.train()
@@ -275,12 +281,12 @@ def train(args):
275
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
276
 
277
  accelerator.backward(loss)
278
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
279
  if train_text_encoder:
280
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
281
  else:
282
  params_to_clip = unet.parameters()
283
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
284
 
285
  optimizer.step()
286
  lr_scheduler.step()
@@ -291,13 +297,9 @@ def train(args):
291
  progress_bar.update(1)
292
  global_step += 1
293
 
294
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
295
-
296
  current_loss = loss.detach().item()
297
  if args.logging_dir is not None:
298
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
299
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
300
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
301
  accelerator.log(logs, step=global_step)
302
 
303
  if epoch == 0:
@@ -324,8 +326,6 @@ def train(args):
324
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
325
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
326
 
327
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
328
-
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
@@ -352,8 +352,6 @@ if __name__ == '__main__':
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
355
- train_util.add_optimizer_arguments(parser)
356
- config_util.add_config_arguments(parser)
357
 
358
  parser.add_argument("--no_token_padding", action="store_true",
359
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
 
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
+ from library.train_util import DreamBoothDataset
 
 
 
 
19
 
20
 
21
  def collate_fn(examples):
 
33
 
34
  tokenizer = train_util.load_tokenizer(args)
35
 
36
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
37
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
38
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
39
+ args.bucket_reso_steps, args.bucket_no_upscale,
40
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  if args.no_token_padding:
43
+ train_dataset.disable_token_padding()
44
+
45
+ # set the caption dropout rates for the training data / 学習データのdropout率を設定する
46
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
47
+
48
+ train_dataset.make_buckets()
49
 
50
  if args.debug_dataset:
51
+ train_util.debug_dataset(train_dataset)
52
  return
53
 
 
 
 
54
  # acceleratorを準備する
55
  print("prepare accelerator")
56
 
 
91
  vae.requires_grad_(False)
92
  vae.eval()
93
  with torch.no_grad():
94
+ train_dataset.cache_latents(vae)
95
  vae.to("cpu")
96
  if torch.cuda.is_available():
97
  torch.cuda.empty_cache()
 
115
 
116
  # 学習に必要なクラスを準備する
117
  print("prepare optimizer, data loader etc.")
118
+
119
+ # use 8-bit Adam / 8-bit Adamを使う
120
+ if args.use_8bit_adam:
121
+ try:
122
+ import bitsandbytes as bnb
123
+ except ImportError:
124
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
125
+ print("use 8-bit Adam optimizer")
126
+ optimizer_class = bnb.optim.AdamW8bit
127
+ elif args.use_lion_optimizer:
128
+ try:
129
+ import lion_pytorch
130
+ except ImportError:
131
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
132
+ print("use Lion optimizer")
133
+ optimizer_class = lion_pytorch.Lion
134
+ else:
135
+ optimizer_class = torch.optim.AdamW
136
+
137
  if train_text_encoder:
138
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
139
  else:
140
  trainable_params = unet.parameters()
141
 
142
+ # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now / betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
143
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
144
 
145
  # dataloaderを準備する
146
  # DataLoaderのプロセス数:0はメインプロセスになる
147
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
148
  train_dataloader = torch.utils.data.DataLoader(
149
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
150
 
151
  # 学習ステップ数を計算する
152
  if args.max_train_epochs is not None:
 
156
  if args.stop_text_encoder_training is None:
157
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
158
 
159
+ # prepare the lr scheduler / lr schedulerを用意する
160
+ lr_scheduler = diffusers.optimization.get_scheduler(
161
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
162
 
163
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
164
  if args.full_fp16:
 
195
  # 学習する
196
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
197
  print("running training / 学習開始")
198
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
199
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
200
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
201
  print(f" num epochs / epoch数: {num_train_epochs}")
202
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
217
  loss_total = 0.0
218
  for epoch in range(num_train_epochs):
219
  print(f"epoch {epoch+1}/{num_train_epochs}")
220
+ train_dataset.set_current_epoch(epoch + 1)
221
 
222
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
223
  unet.train()
 
281
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
282
 
283
  accelerator.backward(loss)
284
+ if accelerator.sync_gradients:
285
  if train_text_encoder:
286
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
287
  else:
288
  params_to_clip = unet.parameters()
289
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
 
297
  progress_bar.update(1)
298
  global_step += 1
299
 
 
 
300
  current_loss = loss.detach().item()
301
  if args.logging_dir is not None:
302
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
303
  accelerator.log(logs, step=global_step)
304
 
305
  if epoch == 0:
 
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
 
 
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
 
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
 
 
355
 
356
  parser.add_argument("--no_token_padding", action="store_true",
357
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_network.py CHANGED
@@ -1,4 +1,8 @@
 
 
 
1
  from torch.nn.parallel import DistributedDataParallel as DDP
 
2
  import importlib
3
  import argparse
4
  import gc
@@ -11,41 +15,94 @@ import json
11
  from tqdm import tqdm
12
  import torch
13
  from accelerate.utils import set_seed
 
14
  from diffusers import DDPMScheduler
15
 
16
  import library.train_util as train_util
17
- from library.train_util import (
18
- DreamBoothDataset,
19
- )
20
- import library.config_util as config_util
21
- from library.config_util import (
22
- ConfigSanitizer,
23
- BlueprintGenerator,
24
- )
25
 
26
 
27
  def collate_fn(examples):
28
  return examples[0]
29
 
30
 
31
- # TODO 他のスクリプトと共通化する
32
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
33
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
34
 
35
  if args.network_train_unet_only:
36
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
37
  elif args.network_train_text_encoder_only:
38
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
39
  else:
40
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
41
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
42
-
43
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
44
- logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
45
 
46
  return logs
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def train(args):
50
  session_id = random.randint(0, 2**32)
51
  training_started_at = time.time()
@@ -54,7 +111,6 @@ def train(args):
54
 
55
  cache_latents = args.cache_latents
56
  use_dreambooth_method = args.in_json is None
57
- use_user_config = args.dataset_config is not None
58
 
59
  if args.seed is not None:
60
  set_seed(args.seed)
@@ -62,51 +118,38 @@ def train(args):
62
  tokenizer = train_util.load_tokenizer(args)
63
 
64
  # データセットを準備する
65
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
66
- if use_user_config:
67
- print(f"Load dataset config from {args.dataset_config}")
68
- user_config = config_util.load_user_config(args.dataset_config)
69
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
70
- if any(getattr(args, attr) is not None for attr in ignored):
71
- print(
72
- "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
73
  else:
74
- if use_dreambooth_method:
75
- print("Use DreamBooth method.")
76
- user_config = {
77
- "datasets": [{
78
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
79
- }]
80
- }
81
- else:
82
- print("Train with captions.")
83
- user_config = {
84
- "datasets": [{
85
- "subsets": [{
86
- "image_dir": args.train_data_dir,
87
- "metadata_file": args.in_json,
88
- }]
89
- }]
90
- }
91
-
92
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
93
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
94
 
95
  if args.debug_dataset:
96
- train_util.debug_dataset(train_dataset_group)
97
  return
98
- if len(train_dataset_group) == 0:
99
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
100
  return
101
 
102
- if cache_latents:
103
- assert train_dataset_group.is_latent_cacheable(
104
- ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
105
-
106
  # acceleratorを準備する
107
  print("prepare accelerator")
108
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
109
- is_main_process = accelerator.is_main_process
110
 
111
  # mixed precisionに対応した型を用意しておき適宜castする
112
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -118,7 +161,7 @@ def train(args):
118
  if args.lowram:
119
  text_encoder.to("cuda")
120
  unet.to("cuda")
121
-
122
  # モデルに xformers とか memory efficient attention を組み込む
123
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
124
 
@@ -128,15 +171,13 @@ def train(args):
128
  vae.requires_grad_(False)
129
  vae.eval()
130
  with torch.no_grad():
131
- train_dataset_group.cache_latents(vae)
132
  vae.to("cpu")
133
  if torch.cuda.is_available():
134
  torch.cuda.empty_cache()
135
  gc.collect()
136
 
137
  # prepare network
138
- import sys
139
- sys.path.append(os.path.dirname(__file__))
140
  print("import network module:", args.network_module)
141
  network_module = importlib.import_module(args.network_module)
142
 
@@ -167,25 +208,48 @@ def train(args):
167
  # 学習に必要なクラスを準備する
168
  print("prepare optimizer, data loader etc.")
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
171
- optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
172
 
173
  # dataloaderを準備する
174
  # DataLoaderのプロセス数:0はメインプロセスになる
175
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
176
  train_dataloader = torch.utils.data.DataLoader(
177
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
178
 
179
  # 学習ステップ数を計算する
180
  if args.max_train_epochs is not None:
181
- args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
182
- if is_main_process:
183
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
187
- num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
188
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
 
189
 
190
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
191
  if args.full_fp16:
@@ -253,21 +317,17 @@ def train(args):
253
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
254
 
255
  # 学習する
256
- # TODO: find a way to handle total batch size when there are multiple datasets
257
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
258
-
259
- if is_main_process:
260
- print("running training / 学習開始")
261
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
262
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
263
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
264
- print(f" num epochs / epoch数: {num_train_epochs}")
265
- print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
266
- # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
267
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
268
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
269
-
270
- # TODO refactor metadata creation and move to util
271
  metadata = {
272
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
273
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -275,10 +335,12 @@ def train(args):
275
  "ss_learning_rate": args.learning_rate,
276
  "ss_text_encoder_lr": args.text_encoder_lr,
277
  "ss_unet_lr": args.unet_lr,
278
- "ss_num_train_images": train_dataset_group.num_train_images,
279
- "ss_num_reg_images": train_dataset_group.num_reg_images,
280
  "ss_num_batches_per_epoch": len(train_dataloader),
281
  "ss_num_epochs": num_train_epochs,
 
 
282
  "ss_gradient_checkpointing": args.gradient_checkpointing,
283
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
284
  "ss_max_train_steps": args.max_train_steps,
@@ -290,156 +352,33 @@ def train(args):
290
  "ss_mixed_precision": args.mixed_precision,
291
  "ss_full_fp16": bool(args.full_fp16),
292
  "ss_v2": bool(args.v2),
 
293
  "ss_clip_skip": args.clip_skip,
294
  "ss_max_token_length": args.max_token_length,
 
 
 
 
295
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
296
  "ss_seed": args.seed,
297
- "ss_lowram": args.lowram,
298
  "ss_noise_offset": args.noise_offset,
 
 
 
 
299
  "ss_training_comment": args.training_comment, # will not be updated after training
300
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
301
- "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
302
- "ss_max_grad_norm": args.max_grad_norm,
303
- "ss_caption_dropout_rate": args.caption_dropout_rate,
304
- "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
305
- "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
306
- "ss_face_crop_aug_range": args.face_crop_aug_range,
307
- "ss_prior_loss_weight": args.prior_loss_weight,
308
  }
309
 
310
- if use_user_config:
311
- # save metadata of multiple datasets
312
- # NOTE: pack "ss_datasets" value as json one time
313
- # or should also pack nested collections as json?
314
- datasets_metadata = []
315
- tag_frequency = {} # merge tag frequency for metadata editor
316
- dataset_dirs_info = {} # merge subset dirs for metadata editor
317
-
318
- for dataset in train_dataset_group.datasets:
319
- is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
320
- dataset_metadata = {
321
- "is_dreambooth": is_dreambooth_dataset,
322
- "batch_size_per_device": dataset.batch_size,
323
- "num_train_images": dataset.num_train_images, # includes repeating
324
- "num_reg_images": dataset.num_reg_images,
325
- "resolution": (dataset.width, dataset.height),
326
- "enable_bucket": bool(dataset.enable_bucket),
327
- "min_bucket_reso": dataset.min_bucket_reso,
328
- "max_bucket_reso": dataset.max_bucket_reso,
329
- "tag_frequency": dataset.tag_frequency,
330
- "bucket_info": dataset.bucket_info,
331
- }
332
-
333
- subsets_metadata = []
334
- for subset in dataset.subsets:
335
- subset_metadata = {
336
- "img_count": subset.img_count,
337
- "num_repeats": subset.num_repeats,
338
- "color_aug": bool(subset.color_aug),
339
- "flip_aug": bool(subset.flip_aug),
340
- "random_crop": bool(subset.random_crop),
341
- "shuffle_caption": bool(subset.shuffle_caption),
342
- "keep_tokens": subset.keep_tokens,
343
- }
344
-
345
- image_dir_or_metadata_file = None
346
- if subset.image_dir:
347
- image_dir = os.path.basename(subset.image_dir)
348
- subset_metadata["image_dir"] = image_dir
349
- image_dir_or_metadata_file = image_dir
350
-
351
- if is_dreambooth_dataset:
352
- subset_metadata["class_tokens"] = subset.class_tokens
353
- subset_metadata["is_reg"] = subset.is_reg
354
- if subset.is_reg:
355
- image_dir_or_metadata_file = None # not merging reg dataset
356
- else:
357
- metadata_file = os.path.basename(subset.metadata_file)
358
- subset_metadata["metadata_file"] = metadata_file
359
- image_dir_or_metadata_file = metadata_file # may overwrite
360
-
361
- subsets_metadata.append(subset_metadata)
362
-
363
- # merge dataset dir: not reg subset only
364
- # TODO update additional-network extension to show detailed dataset config from metadata
365
- if image_dir_or_metadata_file is not None:
366
- # datasets may have a certain dir multiple times
367
- v = image_dir_or_metadata_file
368
- i = 2
369
- while v in dataset_dirs_info:
370
- v = image_dir_or_metadata_file + f" ({i})"
371
- i += 1
372
- image_dir_or_metadata_file = v
373
-
374
- dataset_dirs_info[image_dir_or_metadata_file] = {
375
- "n_repeats": subset.num_repeats,
376
- "img_count": subset.img_count
377
- }
378
-
379
- dataset_metadata["subsets"] = subsets_metadata
380
- datasets_metadata.append(dataset_metadata)
381
-
382
- # merge tag frequency:
383
- for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
384
- # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
385
- # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
386
- # なので、ここで複数datasetの回数を合算してもあまり意味はない
387
- if ds_dir_name in tag_frequency:
388
- continue
389
- tag_frequency[ds_dir_name] = ds_freq_for_dir
390
-
391
- metadata["ss_datasets"] = json.dumps(datasets_metadata)
392
- metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
393
- metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
394
- else:
395
- # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
396
- assert len(
397
- train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
398
-
399
- dataset = train_dataset_group.datasets[0]
400
-
401
- dataset_dirs_info = {}
402
- reg_dataset_dirs_info = {}
403
- if use_dreambooth_method:
404
- for subset in dataset.subsets:
405
- info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
406
- info[os.path.basename(subset.image_dir)] = {
407
- "n_repeats": subset.num_repeats,
408
- "img_count": subset.img_count
409
- }
410
- else:
411
- for subset in dataset.subsets:
412
- dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
413
- "n_repeats": subset.num_repeats,
414
- "img_count": subset.img_count
415
- }
416
-
417
- metadata.update({
418
- "ss_batch_size_per_device": args.train_batch_size,
419
- "ss_total_batch_size": total_batch_size,
420
- "ss_resolution": args.resolution,
421
- "ss_color_aug": bool(args.color_aug),
422
- "ss_flip_aug": bool(args.flip_aug),
423
- "ss_random_crop": bool(args.random_crop),
424
- "ss_shuffle_caption": bool(args.shuffle_caption),
425
- "ss_enable_bucket": bool(dataset.enable_bucket),
426
- "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
427
- "ss_min_bucket_reso": dataset.min_bucket_reso,
428
- "ss_max_bucket_reso": dataset.max_bucket_reso,
429
- "ss_keep_tokens": args.keep_tokens,
430
- "ss_dataset_dirs": json.dumps(dataset_dirs_info),
431
- "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
432
- "ss_tag_frequency": json.dumps(dataset.tag_frequency),
433
- "ss_bucket_info": json.dumps(dataset.bucket_info),
434
- })
435
-
436
- # add extra args
437
- if args.network_args:
438
- metadata["ss_network_args"] = json.dumps(net_kwargs)
439
- # for key, value in net_kwargs.items():
440
- # metadata["ss_arg_" + key] = value
441
-
442
- # model name and hash
443
  if args.pretrained_model_name_or_path is not None:
444
  sd_model_name = args.pretrained_model_name_or_path
445
  if os.path.exists(sd_model_name):
@@ -458,13 +397,6 @@ def train(args):
458
 
459
  metadata = {k: str(v) for k, v in metadata.items()}
460
 
461
- # make minimum metadata for filtering
462
- minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
463
- minimum_metadata = {}
464
- for key in minimum_keys:
465
- if key in metadata:
466
- minimum_metadata[key] = metadata[key]
467
-
468
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
469
  global_step = 0
470
 
@@ -477,9 +409,8 @@ def train(args):
477
  loss_list = []
478
  loss_total = 0.0
479
  for epoch in range(num_train_epochs):
480
- if is_main_process:
481
- print(f"epoch {epoch+1}/{num_train_epochs}")
482
- train_dataset_group.set_current_epoch(epoch + 1)
483
 
484
  metadata["ss_epoch"] = str(epoch+1)
485
 
@@ -516,7 +447,7 @@ def train(args):
516
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
517
 
518
  # Predict the noise residual
519
- with accelerator.autocast():
520
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
521
 
522
  if args.v_parameterization:
@@ -534,9 +465,9 @@ def train(args):
534
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
535
 
536
  accelerator.backward(loss)
537
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
538
  params_to_clip = network.get_trainable_params()
539
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
540
 
541
  optimizer.step()
542
  lr_scheduler.step()
@@ -547,8 +478,6 @@ def train(args):
547
  progress_bar.update(1)
548
  global_step += 1
549
 
550
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
551
-
552
  current_loss = loss.detach().item()
553
  if epoch == 0:
554
  loss_list.append(current_loss)
@@ -579,9 +508,8 @@ def train(args):
579
  def save_func():
580
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
581
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
582
- metadata["ss_training_finished_at"] = str(time.time())
583
  print(f"saving checkpoint: {ckpt_file}")
584
- unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
585
 
586
  def remove_old_func(old_epoch_no):
587
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -590,18 +518,15 @@ def train(args):
590
  print(f"removing old checkpoint: {old_ckpt_file}")
591
  os.remove(old_ckpt_file)
592
 
593
- if is_main_process:
594
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
595
- if saving and args.save_state:
596
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
597
-
598
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
599
 
600
  # end of epoch
601
 
602
  metadata["ss_epoch"] = str(num_train_epochs)
603
- metadata["ss_training_finished_at"] = str(time.time())
604
 
 
605
  if is_main_process:
606
  network = unwrap_model(network)
607
 
@@ -620,7 +545,7 @@ def train(args):
620
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
621
 
622
  print(f"save trained model to {ckpt_file}")
623
- network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
624
  print("model saved.")
625
 
626
 
@@ -630,8 +555,6 @@ if __name__ == '__main__':
630
  train_util.add_sd_models_arguments(parser)
631
  train_util.add_dataset_arguments(parser, True, True, True)
632
  train_util.add_training_arguments(parser, True)
633
- train_util.add_optimizer_arguments(parser)
634
- config_util.add_config_arguments(parser)
635
 
636
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
637
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -639,6 +562,10 @@ if __name__ == '__main__':
639
 
640
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
641
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
642
 
643
  parser.add_argument("--network_weights", type=str, default=None,
644
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
1
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
+ from torch.optim import Optimizer
3
+ from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
 
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
+ import diffusers
19
  from diffusers import DDPMScheduler
20
 
21
  import library.train_util as train_util
22
+ from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
23
 
24
 
25
  def collate_fn(examples):
26
  return examples[0]
27
 
28
 
 
29
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
30
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
31
 
32
  if args.network_train_unet_only:
33
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
34
  elif args.network_train_text_encoder_only:
35
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
36
  else:
37
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
38
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
 
 
 
39
 
40
  return logs
41
 
42
 
43
+ # Monkeypatch a newer get_scheduler() function, overriding the current version of diffusers.optimization.get_scheduler
44
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
45
+ # which is a newer release of diffusers than the one currently packaged with sd-scripts
46
+ # This code can be removed once a newer diffusers version (v0.12.1 or greater) is tested and integrated into sd-scripts
47
+
48
+
49
+ def get_scheduler_fix(
50
+ name: Union[str, SchedulerType],
51
+ optimizer: Optimizer,
52
+ num_warmup_steps: Optional[int] = None,
53
+ num_training_steps: Optional[int] = None,
54
+ num_cycles: int = 1,
55
+ power: float = 1.0,
56
+ ):
57
+ """
58
+ Unified API to get any scheduler from its name.
59
+ Args:
60
+ name (`str` or `SchedulerType`):
61
+ The name of the scheduler to use.
62
+ optimizer (`torch.optim.Optimizer`):
63
+ The optimizer that will be used during training.
64
+ num_warmup_steps (`int`, *optional*):
65
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
66
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
67
+ num_training_steps (`int`, *optional*):
68
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
69
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
70
+ num_cycles (`int`, *optional*):
71
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
72
+ power (`float`, *optional*, defaults to 1.0):
73
+ Power factor. See `POLYNOMIAL` scheduler
74
+ last_epoch (`int`, *optional*, defaults to -1):
75
+ The index of the last epoch when resuming training.
76
+ """
77
+ name = SchedulerType(name)
78
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
79
+ if name == SchedulerType.CONSTANT:
80
+ return schedule_func(optimizer)
81
+
82
+ # All other schedulers require `num_warmup_steps`
83
+ if num_warmup_steps is None:
84
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
85
+
86
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
87
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
88
+
89
+ # All other schedulers require `num_training_steps`
90
+ if num_training_steps is None:
91
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
92
+
93
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
94
+ return schedule_func(
95
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
96
+ )
97
+
98
+ if name == SchedulerType.POLYNOMIAL:
99
+ return schedule_func(
100
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
101
+ )
102
+
103
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
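For context, the call later in train() feeds this function the args-driven warmup and total step counts plus the two new arguments; a minimal usage sketch, assuming an optimizer has already been built (the values below are placeholders):

    # illustrative values only; in train() these come from args
    lr_scheduler = get_scheduler_fix(
        "cosine_with_restarts", optimizer,
        num_warmup_steps=100,
        num_training_steps=10_000,
        num_cycles=3)  # power= is consulted only by the "polynomial" scheduler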
104
+
105
+
106
  def train(args):
107
  session_id = random.randint(0, 2**32)
108
  training_started_at = time.time()
 
111
 
112
  cache_latents = args.cache_latents
113
  use_dreambooth_method = args.in_json is None
 
114
 
115
  if args.seed is not None:
116
  set_seed(args.seed)
 
118
  tokenizer = train_util.load_tokenizer(args)
119
 
120
  # データセットを準備する
121
+ if use_dreambooth_method:
122
+ print("Use DreamBooth method.")
123
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
124
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
125
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
126
+ args.bucket_reso_steps, args.bucket_no_upscale,
127
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
128
+ args.random_crop, args.debug_dataset)
129
  else:
130
+ print("Train with captions.")
131
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
132
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
133
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
134
+ args.bucket_reso_steps, args.bucket_no_upscale,
135
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
136
+ args.dataset_repeats, args.debug_dataset)
137
+
138
+ # set the caption dropout rates for the training data / 学習データのdropout率を設定する
139
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
140
+
141
+ train_dataset.make_buckets()
 
 
 
 
 
 
 
 
142
 
143
  if args.debug_dataset:
144
+ train_util.debug_dataset(train_dataset)
145
  return
146
+ if len(train_dataset) == 0:
147
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
148
  return
149
 
 
 
 
 
150
  # acceleratorを準備する
151
  print("prepare accelerator")
152
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
153
 
154
  # mixed precisionに対応した型を用意しておき適宜castする
155
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
161
  if args.lowram:
162
  text_encoder.to("cuda")
163
  unet.to("cuda")
164
+
165
  # モデルに xformers とか memory efficient attention を組み込む
166
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
167
 
 
171
  vae.requires_grad_(False)
172
  vae.eval()
173
  with torch.no_grad():
174
+ train_dataset.cache_latents(vae)
175
  vae.to("cpu")
176
  if torch.cuda.is_available():
177
  torch.cuda.empty_cache()
178
  gc.collect()
179
 
180
  # prepare network
 
 
181
  print("import network module:", args.network_module)
182
  network_module = importlib.import_module(args.network_module)
183
 
 
208
  # 学習に必要なクラスを準備する
209
  print("prepare optimizer, data loader etc.")
210
 
211
+ # use 8-bit Adam / 8-bit Adamを使う
212
+ if args.use_8bit_adam:
213
+ try:
214
+ import bitsandbytes as bnb
215
+ except ImportError:
216
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
217
+ print("use 8-bit Adam optimizer")
218
+ optimizer_class = bnb.optim.AdamW8bit
219
+ elif args.use_lion_optimizer:
220
+ try:
221
+ import lion_pytorch
222
+ except ImportError:
223
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
224
+ print("use Lion optimizer")
225
+ optimizer_class = lion_pytorch.Lion
226
+ else:
227
+ optimizer_class = torch.optim.AdamW
228
+
229
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
230
+
231
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
232
+
233
+ # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now / betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
234
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235
 
236
  # dataloaderを準備する
237
  # DataLoaderのプロセス数:0はメインプロセスになる
238
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
239
  train_dataloader = torch.utils.data.DataLoader(
240
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
241
 
242
  # 学習ステップ数を計算する
243
  if args.max_train_epochs is not None:
244
+ args.max_train_steps = args.max_train_epochs * len(train_dataloader)
245
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
246
 
247
  # lr schedulerを用意する
248
+ # lr_scheduler = diffusers.optimization.get_scheduler(
249
+ lr_scheduler = get_scheduler_fix(
250
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
251
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
252
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
253
 
254
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
255
  if args.full_fp16:
 
317
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
318
 
319
  # 学習する
 
320
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
321
+ print("running training / 学習開始")
322
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
323
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
324
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
325
+ print(f" num epochs / epoch数: {num_train_epochs}")
326
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
327
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
328
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
329
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
330
+
 
 
 
331
  metadata = {
332
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
333
  "ss_training_started_at": training_started_at, # unix timestamp
 
335
  "ss_learning_rate": args.learning_rate,
336
  "ss_text_encoder_lr": args.text_encoder_lr,
337
  "ss_unet_lr": args.unet_lr,
338
+ "ss_num_train_images": train_dataset.num_train_images, # includes repeating
339
+ "ss_num_reg_images": train_dataset.num_reg_images,
340
  "ss_num_batches_per_epoch": len(train_dataloader),
341
  "ss_num_epochs": num_train_epochs,
342
+ "ss_batch_size_per_device": args.train_batch_size,
343
+ "ss_total_batch_size": total_batch_size,
344
  "ss_gradient_checkpointing": args.gradient_checkpointing,
345
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
346
  "ss_max_train_steps": args.max_train_steps,
 
352
  "ss_mixed_precision": args.mixed_precision,
353
  "ss_full_fp16": bool(args.full_fp16),
354
  "ss_v2": bool(args.v2),
355
+ "ss_resolution": args.resolution,
356
  "ss_clip_skip": args.clip_skip,
357
  "ss_max_token_length": args.max_token_length,
358
+ "ss_color_aug": bool(args.color_aug),
359
+ "ss_flip_aug": bool(args.flip_aug),
360
+ "ss_random_crop": bool(args.random_crop),
361
+ "ss_shuffle_caption": bool(args.shuffle_caption),
362
  "ss_cache_latents": bool(args.cache_latents),
363
+ "ss_enable_bucket": bool(train_dataset.enable_bucket),
364
+ "ss_min_bucket_reso": train_dataset.min_bucket_reso,
365
+ "ss_max_bucket_reso": train_dataset.max_bucket_reso,
366
  "ss_seed": args.seed,
367
+ "ss_keep_tokens": args.keep_tokens,
368
  "ss_noise_offset": args.noise_offset,
369
+ "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
370
+ "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
371
+ "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
372
+ "ss_bucket_info": json.dumps(train_dataset.bucket_info),
373
  "ss_training_comment": args.training_comment, # will not be updated after training
374
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
375
+ "ss_optimizer": optimizer_name
 
 
 
 
 
 
376
  }
377
 
378
+ # uncomment if another network is added
379
+ # for key, value in net_kwargs.items():
380
+ # metadata["ss_arg_" + key] = value
381
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  if args.pretrained_model_name_or_path is not None:
383
  sd_model_name = args.pretrained_model_name_or_path
384
  if os.path.exists(sd_model_name):
 
397
 
398
  metadata = {k: str(v) for k, v in metadata.items()}
399
 
 
 
 
 
 
 
 
400
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
401
  global_step = 0
402
 
 
409
  loss_list = []
410
  loss_total = 0.0
411
  for epoch in range(num_train_epochs):
412
+ print(f"epoch {epoch+1}/{num_train_epochs}")
413
+ train_dataset.set_current_epoch(epoch + 1)
 
414
 
415
  metadata["ss_epoch"] = str(epoch+1)
416
 
 
447
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
448
 
449
  # Predict the noise residual
450
+ with autocast():
451
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
452
 
453
  if args.v_parameterization:
 
465
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
466
 
467
  accelerator.backward(loss)
468
+ if accelerator.sync_gradients:
469
  params_to_clip = network.get_trainable_params()
470
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
471
 
472
  optimizer.step()
473
  lr_scheduler.step()
 
478
  progress_bar.update(1)
479
  global_step += 1
480
 
 
 
481
  current_loss = loss.detach().item()
482
  if epoch == 0:
483
  loss_list.append(current_loss)
 
508
  def save_func():
509
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
510
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
511
  print(f"saving checkpoint: {ckpt_file}")
512
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
513
 
514
  def remove_old_func(old_epoch_no):
515
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
518
  print(f"removing old checkpoint: {old_ckpt_file}")
519
  os.remove(old_ckpt_file)
520
 
521
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
522
+ if saving and args.save_state:
523
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
524
 
525
  # end of epoch
526
 
527
  metadata["ss_epoch"] = str(num_train_epochs)
 
528
 
529
+ is_main_process = accelerator.is_main_process
530
  if is_main_process:
531
  network = unwrap_model(network)
532
 
 
545
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
546
 
547
  print(f"save trained model to {ckpt_file}")
548
+ network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
549
  print("model saved.")
550
 
551
 
 
555
  train_util.add_sd_models_arguments(parser)
556
  train_util.add_dataset_arguments(parser, True, True, True)
557
  train_util.add_training_arguments(parser, True)
 
 
558
 
559
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
560
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
562
 
563
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
564
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
565
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
566
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
567
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
568
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
569
 
570
  parser.add_argument("--network_weights", type=str, default=None,
571
  help="pretrained weights for network / 学習するネットワークの初期重み")
train_network_opt.py CHANGED
@@ -1,5 +1,8 @@
 
 
1
  from torch.cuda.amp import autocast
2
  from torch.nn.parallel import DistributedDataParallel as DDP
 
3
  import importlib
4
  import argparse
5
  import gc
@@ -12,49 +15,138 @@ import json
12
  from tqdm import tqdm
13
  import torch
14
  from accelerate.utils import set_seed
15
- #import diffusers
16
  from diffusers import DDPMScheduler
17
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  ##### バケット拡張のためのモジュール
19
  import append_module
20
  ######
21
  import library.train_util as train_util
22
- from library.train_util import (
23
- DreamBoothDataset,
24
- )
25
- import library.config_util as config_util
26
- from library.config_util import (
27
- ConfigSanitizer,
28
- BlueprintGenerator,
29
- )
30
 
31
 
32
  def collate_fn(examples):
33
  return examples[0]
34
 
35
 
36
- # TODO 他のスクリプトと共通化する
37
- def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
38
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
39
- if not args.split_lora_networks:
40
- if args.network_train_unet_only:
41
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
42
- elif args.network_train_text_encoder_only:
43
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
44
- else:
45
- logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
46
- logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
47
  else:
48
  last_lrs = lr_scheduler.get_last_lr()
49
- for last_lr, t_name in zip(last_lrs, split_names):
50
- logs[f"lr/{t_name}"] = float(last_lr)
51
- #D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
52
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
53
- logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  return logs
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def train(args):
59
  session_id = random.randint(0, 2**32)
60
  training_started_at = time.time()
@@ -63,7 +155,6 @@ def train(args):
63
 
64
  cache_latents = args.cache_latents
65
  use_dreambooth_method = args.in_json is None
66
- use_user_config = args.dataset_config is not None
67
 
68
  if args.seed is not None:
69
  set_seed(args.seed)
@@ -71,72 +162,52 @@ def train(args):
71
  tokenizer = train_util.load_tokenizer(args)
72
 
73
  # データセットを準備する
74
- if args.min_resolution:
75
- args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
76
- if len(args.min_resolution) == 1:
77
- args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
78
- blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
79
- else:
80
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
81
- if use_user_config:
82
- print(f"Load dataset config from {args.dataset_config}")
83
- user_config = config_util.load_user_config(args.dataset_config)
84
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
85
- if any(getattr(args, attr) is not None for attr in ignored):
86
- print(
87
- "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
88
- else:
89
- if use_dreambooth_method:
90
- print("Use DreamBooth method.")
91
- user_config = {
92
- "datasets": [{
93
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
94
- }]
95
- }
96
- else:
97
- print("Train with captions.")
98
- user_config = {
99
- "datasets": [{
100
- "subsets": [{
101
- "image_dir": args.train_data_dir,
102
- "metadata_file": args.in_json,
103
- }]
104
- }]
105
- }
106
-
107
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
108
- if args.min_resolution:
109
- train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
110
  else:
111
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if args.debug_dataset:
114
- train_util.debug_dataset(train_dataset_group)
115
  return
116
- if len(train_dataset_group) == 0:
117
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
118
  return
119
 
120
- if cache_latents:
121
- assert train_dataset_group.is_latent_cacheable(
122
- ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
123
-
124
  # acceleratorを準備する
125
  print("prepare accelerator")
126
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
127
- is_main_process = accelerator.is_main_process
128
 
129
  # mixed precisionに対応した型を用意しておき適宜castする
130
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
131
 
132
  # モデルを読み込む
133
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
134
-
135
- # work on low-ram device
136
- if args.lowram:
137
- text_encoder.to("cuda")
138
- unet.to("cuda")
139
-
140
  # モデルに xformers とか memory efficient attention を組み込む
141
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
142
 
@@ -146,15 +217,13 @@ def train(args):
146
  vae.requires_grad_(False)
147
  vae.eval()
148
  with torch.no_grad():
149
- train_dataset_group.cache_latents(vae)
150
  vae.to("cpu")
151
  if torch.cuda.is_available():
152
  torch.cuda.empty_cache()
153
  gc.collect()
154
 
155
  # prepare network
156
- import sys
157
- sys.path.append(os.path.dirname(__file__))
158
  print("import network module:", args.network_module)
159
  network_module = importlib.import_module(args.network_module)
160
 
@@ -184,65 +253,188 @@ def train(args):
184
 
185
  # 学習に必要なクラスを準備する
186
  print("prepare optimizer, data loader etc.")
187
- split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
188
-
189
- used_names = None
190
  if args.split_lora_networks:
191
- lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
192
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
193
- append_module.replace_prepare_optimizer_params(network, network_module)
194
- trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
195
  else:
196
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
197
- if split_flag:
198
- _t_lr = 0.
199
- _u_lr = 0.
200
- if args.text_encoder_lr:
201
- _t_lr = args.text_encoder_lr
202
- if args.unet_lr:
203
- _u_lr = args.unet_lr
204
- adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
205
-
206
- optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
207
- if args.use_lookahead:
208
- try:
209
- import torch_optimizer
210
- lookahed_arg = {"k": 5, "alpha": 0.5}
211
- if args.lookahead_arg is not None:
212
- for _arg in args.lookahead_arg:
213
- k, v = _arg.split("=")
214
- if k == "k":
215
- lookahed_arg[k] = int(v)
216
- else:
217
- lookahed_arg[k] = float(v)
218
- optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
219
- except:
220
- print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
221
  # dataloaderを準備する
222
  # DataLoaderのプロセス数:0はメインプロセスになる
223
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
224
  train_dataloader = torch.utils.data.DataLoader(
225
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
226
 
227
  # 学習ステップ数を計算する
228
  if args.max_train_epochs is not None:
229
- args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
230
- if is_main_process:
231
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
232
 
233
  # lr schedulerを用意する
234
- if args.lr_scheduler.startswith("adafactor") and split_flag:
235
- lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
 
 
 
 
236
  else:
237
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
238
- num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
239
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
240
-
 
241
  #追加機能の設定をコメントに追記して残す
242
- if args.use_lookahead:
243
- args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
244
- if args.split_lora_networks:
245
- args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
246
  if args.min_resolution:
247
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
248
 
@@ -312,21 +504,17 @@ def train(args):
312
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
313
 
314
  # 学習する
315
- # TODO: find a way to handle total batch size when there are multiple datasets
316
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
317
-
318
- if is_main_process:
319
- print("running training / 学習開始")
320
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
321
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
322
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
323
- print(f" num epochs / epoch数: {num_train_epochs}")
324
- print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
325
- # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
326
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
327
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
328
-
329
- # TODO refactor metadata creation and move to util
330
  metadata = {
331
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
332
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -334,10 +522,12 @@ def train(args):
334
  "ss_learning_rate": args.learning_rate,
335
  "ss_text_encoder_lr": args.text_encoder_lr,
336
  "ss_unet_lr": args.unet_lr,
337
- "ss_num_train_images": train_dataset_group.num_train_images,
338
- "ss_num_reg_images": train_dataset_group.num_reg_images,
339
  "ss_num_batches_per_epoch": len(train_dataloader),
340
  "ss_num_epochs": num_train_epochs,
 
 
341
  "ss_gradient_checkpointing": args.gradient_checkpointing,
342
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
343
  "ss_max_train_steps": args.max_train_steps,
@@ -349,156 +539,32 @@ def train(args):
349
  "ss_mixed_precision": args.mixed_precision,
350
  "ss_full_fp16": bool(args.full_fp16),
351
  "ss_v2": bool(args.v2),
 
352
  "ss_clip_skip": args.clip_skip,
353
  "ss_max_token_length": args.max_token_length,
 
 
 
 
354
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
355
  "ss_seed": args.seed,
356
- "ss_lowram": args.lowram,
357
  "ss_noise_offset": args.noise_offset,
 
 
 
 
358
  "ss_training_comment": args.training_comment, # will not be updated after training
359
- "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
360
- "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
361
- "ss_max_grad_norm": args.max_grad_norm,
362
- "ss_caption_dropout_rate": args.caption_dropout_rate,
363
- "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
364
- "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
365
- "ss_face_crop_aug_range": args.face_crop_aug_range,
366
- "ss_prior_loss_weight": args.prior_loss_weight,
367
  }
368
 
369
- if use_user_config:
370
- # save metadata of multiple datasets
371
- # NOTE: pack "ss_datasets" value as json one time
372
- # or should also pack nested collections as json?
373
- datasets_metadata = []
374
- tag_frequency = {} # merge tag frequency for metadata editor
375
- dataset_dirs_info = {} # merge subset dirs for metadata editor
376
-
377
- for dataset in train_dataset_group.datasets:
378
- is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
379
- dataset_metadata = {
380
- "is_dreambooth": is_dreambooth_dataset,
381
- "batch_size_per_device": dataset.batch_size,
382
- "num_train_images": dataset.num_train_images, # includes repeating
383
- "num_reg_images": dataset.num_reg_images,
384
- "resolution": (dataset.width, dataset.height),
385
- "enable_bucket": bool(dataset.enable_bucket),
386
- "min_bucket_reso": dataset.min_bucket_reso,
387
- "max_bucket_reso": dataset.max_bucket_reso,
388
- "tag_frequency": dataset.tag_frequency,
389
- "bucket_info": dataset.bucket_info,
390
- }
391
-
392
- subsets_metadata = []
393
- for subset in dataset.subsets:
394
- subset_metadata = {
395
- "img_count": subset.img_count,
396
- "num_repeats": subset.num_repeats,
397
- "color_aug": bool(subset.color_aug),
398
- "flip_aug": bool(subset.flip_aug),
399
- "random_crop": bool(subset.random_crop),
400
- "shuffle_caption": bool(subset.shuffle_caption),
401
- "keep_tokens": subset.keep_tokens,
402
- }
403
-
404
- image_dir_or_metadata_file = None
405
- if subset.image_dir:
406
- image_dir = os.path.basename(subset.image_dir)
407
- subset_metadata["image_dir"] = image_dir
408
- image_dir_or_metadata_file = image_dir
409
-
410
- if is_dreambooth_dataset:
411
- subset_metadata["class_tokens"] = subset.class_tokens
412
- subset_metadata["is_reg"] = subset.is_reg
413
- if subset.is_reg:
414
- image_dir_or_metadata_file = None # not merging reg dataset
415
- else:
416
- metadata_file = os.path.basename(subset.metadata_file)
417
- subset_metadata["metadata_file"] = metadata_file
418
- image_dir_or_metadata_file = metadata_file # may overwrite
419
-
420
- subsets_metadata.append(subset_metadata)
421
-
422
- # merge dataset dir: not reg subset only
423
- # TODO update additional-network extension to show detailed dataset config from metadata
424
- if image_dir_or_metadata_file is not None:
425
- # datasets may have a certain dir multiple times
426
- v = image_dir_or_metadata_file
427
- i = 2
428
- while v in dataset_dirs_info:
429
- v = image_dir_or_metadata_file + f" ({i})"
430
- i += 1
431
- image_dir_or_metadata_file = v
432
-
433
- dataset_dirs_info[image_dir_or_metadata_file] = {
434
- "n_repeats": subset.num_repeats,
435
- "img_count": subset.img_count
436
- }
437
-
438
- dataset_metadata["subsets"] = subsets_metadata
439
- datasets_metadata.append(dataset_metadata)
440
-
441
- # merge tag frequency:
442
- for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
443
- # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
444
- # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
445
- # なので、ここで複数datasetの回数を合算してもあまり意味はない
446
- if ds_dir_name in tag_frequency:
447
- continue
448
- tag_frequency[ds_dir_name] = ds_freq_for_dir
449
-
450
- metadata["ss_datasets"] = json.dumps(datasets_metadata)
451
- metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
452
- metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
453
- else:
454
- # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
455
- assert len(
456
- train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
457
-
458
- dataset = train_dataset_group.datasets[0]
459
-
460
- dataset_dirs_info = {}
461
- reg_dataset_dirs_info = {}
462
- if use_dreambooth_method:
463
- for subset in dataset.subsets:
464
- info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
465
- info[os.path.basename(subset.image_dir)] = {
466
- "n_repeats": subset.num_repeats,
467
- "img_count": subset.img_count
468
- }
469
- else:
470
- for subset in dataset.subsets:
471
- dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
472
- "n_repeats": subset.num_repeats,
473
- "img_count": subset.img_count
474
- }
475
-
476
- metadata.update({
477
- "ss_batch_size_per_device": args.train_batch_size,
478
- "ss_total_batch_size": total_batch_size,
479
- "ss_resolution": args.resolution,
480
- "ss_color_aug": bool(args.color_aug),
481
- "ss_flip_aug": bool(args.flip_aug),
482
- "ss_random_crop": bool(args.random_crop),
483
- "ss_shuffle_caption": bool(args.shuffle_caption),
484
- "ss_enable_bucket": bool(dataset.enable_bucket),
485
- "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
486
- "ss_min_bucket_reso": dataset.min_bucket_reso,
487
- "ss_max_bucket_reso": dataset.max_bucket_reso,
488
- "ss_keep_tokens": args.keep_tokens,
489
- "ss_dataset_dirs": json.dumps(dataset_dirs_info),
490
- "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
491
- "ss_tag_frequency": json.dumps(dataset.tag_frequency),
492
- "ss_bucket_info": json.dumps(dataset.bucket_info),
493
- })
494
-
495
- # add extra args
496
- if args.network_args:
497
- metadata["ss_network_args"] = json.dumps(net_kwargs)
498
  # for key, value in net_kwargs.items():
499
  # metadata["ss_arg_" + key] = value
500
 
501
- # model name and hash
502
  if args.pretrained_model_name_or_path is not None:
503
  sd_model_name = args.pretrained_model_name_or_path
504
  if os.path.exists(sd_model_name):
@@ -517,13 +583,6 @@ def train(args):
517
 
518
  metadata = {k: str(v) for k, v in metadata.items()}
519
 
520
- # make minimum metadata for filtering
521
- minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
522
- minimum_metadata = {}
523
- for key in minimum_keys:
524
- if key in metadata:
525
- minimum_metadata[key] = metadata[key]
526
-
527
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
528
  global_step = 0
529
 
@@ -536,9 +595,8 @@ def train(args):
536
  loss_list = []
537
  loss_total = 0.0
538
  for epoch in range(num_train_epochs):
539
- if is_main_process:
540
- print(f"epoch {epoch+1}/{num_train_epochs}")
541
- train_dataset_group.set_current_epoch(epoch + 1)
542
 
543
  metadata["ss_epoch"] = str(epoch+1)
544
 
@@ -575,7 +633,7 @@ def train(args):
575
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
576
 
577
  # Predict the noise residual
578
- with accelerator.autocast():
579
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
580
 
581
  if args.v_parameterization:
@@ -593,13 +651,12 @@ def train(args):
593
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
594
 
595
  accelerator.backward(loss)
596
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
597
  params_to_clip = network.get_trainable_params()
598
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
599
 
600
  optimizer.step()
601
- if accelerator.sync_gradients:
602
- lr_scheduler.step()
603
  optimizer.zero_grad(set_to_none=True)
604
 
605
  # Checks if the accelerator has performed an optimization step behind the scenes
@@ -607,8 +664,6 @@ def train(args):
607
  progress_bar.update(1)
608
  global_step += 1
609
 
610
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
611
-
612
  current_loss = loss.detach().item()
613
  if epoch == 0:
614
  loss_list.append(current_loss)
@@ -621,7 +676,7 @@ def train(args):
621
  progress_bar.set_postfix(**logs)
622
 
623
  if args.logging_dir is not None:
624
- logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
625
  accelerator.log(logs, step=global_step)
626
 
627
  if global_step >= args.max_train_steps:
@@ -639,9 +694,8 @@ def train(args):
639
  def save_func():
640
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
641
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
642
- metadata["ss_training_finished_at"] = str(time.time())
643
  print(f"saving checkpoint: {ckpt_file}")
644
- unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
645
 
646
  def remove_old_func(old_epoch_no):
647
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -650,18 +704,15 @@ def train(args):
650
  print(f"removing old checkpoint: {old_ckpt_file}")
651
  os.remove(old_ckpt_file)
652
 
653
- if is_main_process:
654
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
655
- if saving and args.save_state:
656
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
657
-
658
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
659
 
660
  # end of epoch
661
 
662
  metadata["ss_epoch"] = str(num_train_epochs)
663
- metadata["ss_training_finished_at"] = str(time.time())
664
 
 
665
  if is_main_process:
666
  network = unwrap_model(network)
667
 
@@ -680,7 +731,7 @@ def train(args):
680
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
681
 
682
  print(f"save trained model to {ckpt_file}")
683
- network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
684
  print("model saved.")
685
 
686
 
@@ -690,8 +741,6 @@ if __name__ == '__main__':
690
  train_util.add_sd_models_arguments(parser)
691
  train_util.add_dataset_arguments(parser, True, True, True)
692
  train_util.add_training_arguments(parser, True)
693
- train_util.add_optimizer_arguments(parser)
694
- config_util.add_config_arguments(parser)
695
 
696
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
697
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -699,6 +748,10 @@ if __name__ == '__main__':
699
 
700
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
701
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
702
 
703
  parser.add_argument("--network_weights", type=str, default=None,
704
  help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -718,30 +771,27 @@ if __name__ == '__main__':
718
  #Optimizer変更関連のオプション追加
719
  append_module.add_append_arguments(parser)
720
  args = append_module.get_config(parser)
721
- if not args.not_output_config:
722
- #argsを保存する
723
- import yaml
724
- import datetime
725
- _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
726
- if args.output_name==None:
727
- config_name = f"train_network_config_{_t}.yaml"
728
- else:
729
- config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
730
- print(f"{config_name} に設定を書き出し中...")
731
- with open(config_name, mode="w") as f:
732
- yaml.dump(args.__dict__, f, indent=4)
733
 
734
  if args.resolution==args.min_resolution:
735
  args.min_resolution=None
736
 
737
  train(args)
738
- print("done!")
739
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
  '''
742
  optimizer設定メモ
743
- torch_optimizer.AdaBelief
744
- adastand.Adastand
745
  (optimizer_argから設定できるように変更するためのメモ)
746
 
747
  AdamWのweight_decay初期値は1e-2
@@ -771,7 +821,6 @@ Adafactor
771
  transformerベースのT5学習において最強とかいう噂のoptimizer
772
  huggingfaceのサンプルパラ
773
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
774
- epsの二つ目の値1e-3が学習率に影響大きい
775
 
776
  AggMo
777
 
 
1
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
+ from torch.optim import Optimizer
3
  from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
 
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
+ import diffusers
19
  from diffusers import DDPMScheduler
20
+ print("**********************************")
21
+ # Requires installing torch_optimizer first:
23
+ # pip install torch_optimizer
24
+ # (otherwise the extra optimizers below are unavailable)
24
+ try:
25
+ import torch_optimizer as optim
26
+ except:
27
+ print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
28
+ try:
29
+ import adastand
30
+ except:
31
+ print("※Adastandが使えません")
32
+
33
+ from transformers.optimization import Adafactor, AdafactorSchedule
34
+ print("**********************************")
35
  ##### バケット拡張のためのモジュール
36
  import append_module
37
  ######
38
  import library.train_util as train_util
39
+ from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
40
 
41
 
42
  def collate_fn(examples):
43
  return examples[0]
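# --- Editor's note: a minimal sketch (not part of this commit) of why collate_fn()
# --- above simply unwraps its argument. The dataset classes used here appear to
# --- return one fully-formed batch dict per __getitem__ call, so the DataLoader is
# --- created with batch_size=1 and the collate function strips the singleton list.
example_batch = {"captions": ["a photo of sks"], "loss_weights": [1.0]}  # illustrative keys only
assert collate_fn([example_batch]) is example_batch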
44
 
45
 
46
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
 
47
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
48
+
49
+ if args.network_train_unet_only:
50
+ logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
51
+ elif args.network_train_text_encoder_only:
52
+ logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
 
 
 
53
  else:
54
  last_lrs = lr_scheduler.get_last_lr()
55
+ if len(last_lrs) == 2:
56
+ logs["lr/textencoder"] = float(last_lrs[0])
57
+ logs["lr/unet"] = float(last_lrs[-1]) # may be same to textencoder
58
+ else:
59
+ if len(last_lrs) == 4:
60
+ logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
61
+ elif len(last_lrs) == 8:
62
+ logs_names = ["textencoder", "unet_midblock"]
63
+ for i in range(3):
64
+ logs_names.append(f"unet_down_blocks_{i}")
65
+ logs_names.append(f"unet_up_blocks_{i+1}")
66
+ else:
67
+ logs_names = []
68
+ for i in range(12):
69
+ logs_names.append(f"text_model_encoder_layers_{i}_")
70
+ logs_names.append("unet_midblock")
71
+ for i in range(3):
72
+ logs_names.append(f"unet_down_blocks_{i}")
73
+ logs_names.append(f"unet_up_blocks_{i+1}")
74
+
75
+ for last_lr, logs_name in zip(last_lrs, logs_names):
76
+ logs[f"lr/{logs_name}"] = float(last_lr)
77
 
78
  return logs
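# --- Editor's note: an illustrative sketch (not part of this commit) of how
# --- generate_step_logs() above labels learning rates for a split network.
# FakeScheduler stands in for a real LR scheduler; only get_last_lr() is used,
# and the four values correspond to the 4-group case handled above.
import argparse

class FakeScheduler:
    def get_last_lr(self):
        return [1e-5, 1e-4, 1e-4, 1e-4]

fake_args = argparse.Namespace(network_train_unet_only=False,
                               network_train_text_encoder_only=False)
print(generate_step_logs(fake_args, current_loss=0.12, avr_loss=0.11,
                         lr_scheduler=FakeScheduler()))
# {'loss/current': 0.12, 'loss/average': 0.11, 'lr/textencoder': 1e-05,
#  'lr/lora_unet_mid_block': 0.0001, 'lr/unet_down_blocks': 0.0001,
#  'lr/unet_up_blocks': 0.0001}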
79
 
80
 
81
+ # Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
82
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
83
+ # Which is a newer release of diffusers than currently packaged with sd-scripts
84
+ # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
85
+
86
+
87
+ def get_scheduler_fix(
88
+ name: Union[str, SchedulerType],
89
+ optimizer: Optimizer,
90
+ num_warmup_steps: Optional[int] = None,
91
+ num_training_steps: Optional[int] = None,
92
+ num_cycles: float = 1.,
93
+ power: float = 1.0,
94
+ ):
95
+ """
96
+ Unified API to get any scheduler from its name.
97
+ Args:
98
+ name (`str` or `SchedulerType`):
99
+ The name of the scheduler to use.
100
+ optimizer (`torch.optim.Optimizer`):
101
+ The optimizer that will be used during training.
102
+ num_warmup_steps (`int`, *optional*):
103
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
104
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
105
+ num_training_steps (`int``, *optional*):
106
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
107
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
108
+ num_cycles (`int`, *optional*):
109
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
110
+ power (`float`, *optional*, defaults to 1.0):
111
+ Power factor. See `POLYNOMIAL` scheduler
112
+ last_epoch (`int`, *optional*, defaults to -1):
113
+ The index of the last epoch when resuming training.
114
+ """
115
+ name = SchedulerType(name)
116
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
117
+ if name == SchedulerType.CONSTANT:
118
+ return schedule_func(optimizer)
119
+
120
+ # All other schedulers require `num_warmup_steps`
121
+ if num_warmup_steps is None:
122
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
123
+
124
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
125
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
126
+
127
+ # All other schedulers require `num_training_steps`
128
+ if num_training_steps is None:
129
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
130
+
131
+ if name == SchedulerType.COSINE:
132
+ print(f"{name} num_cycles: {num_cycles}")
133
+ return schedule_func(
134
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
135
+ )
136
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
137
+ print(f"{name} num_cycles: {int(num_cycles)}")
138
+ return schedule_func(
139
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
140
+ )
141
+
142
+ if name == SchedulerType.POLYNOMIAL:
143
+ return schedule_func(
144
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
145
+ )
146
+
147
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
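# --- Editor's note: an illustrative sketch (not part of this commit) showing how
# --- get_scheduler_fix() above is typically called. All values are made up.
import torch

_params = [torch.nn.Parameter(torch.zeros(1))]
_opt = torch.optim.AdamW(_params, lr=1e-4)
_sched = get_scheduler_fix("cosine_with_restarts", _opt,
                           num_warmup_steps=100,
                           num_training_steps=1000,
                           num_cycles=3)
for _ in range(5):
    _opt.step()
    _sched.step()
print(_sched.get_last_lr())  # lr during warmup, then cosine decay with 3 restarts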
148
+
149
+
150
  def train(args):
151
  session_id = random.randint(0, 2**32)
152
  training_started_at = time.time()
 
155
 
156
  cache_latents = args.cache_latents
157
  use_dreambooth_method = args.in_json is None
 
158
 
159
  if args.seed is not None:
160
  set_seed(args.seed)
 
162
  tokenizer = train_util.load_tokenizer(args)
163
 
164
  # データセットを準備する
165
+ if use_dreambooth_method:
166
+ if args.min_resolution:
167
+ args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
168
+ if len(args.min_resolution) == 1:
169
+ args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
170
+
171
+ print("Use DreamBooth method.")
172
+ train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
173
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
174
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
175
+ args.bucket_reso_steps, args.bucket_no_upscale,
176
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
177
+ args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  else:
179
+ print("Train with captions.")
180
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
181
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
182
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
183
+ args.bucket_reso_steps, args.bucket_no_upscale,
184
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
185
+ args.dataset_repeats, args.debug_dataset)
186
+
187
+ # 学習データのdropout率を設定する
188
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
189
+
190
+ train_dataset.make_buckets()
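# --- Editor's note: a small sketch (not part of this commit) of the --min_resolution
# --- normalisation done at the top of this block: "512" becomes a square lower
# --- bound, "384,512" stays a two-value pair.
def _parse_min_resolution(value: str) -> tuple:
    reso = tuple(int(r) for r in value.split(','))
    return (reso[0], reso[0]) if len(reso) == 1 else reso

print(_parse_min_resolution("512"))      # (512, 512)
print(_parse_min_resolution("384,512"))  # (384, 512)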
191
 
192
  if args.debug_dataset:
193
+ train_util.debug_dataset(train_dataset)
194
  return
195
+ if len(train_dataset) == 0:
196
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
197
  return
198
 
 
 
 
 
199
  # acceleratorを準備する
200
  print("prepare accelerator")
201
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
202
 
203
  # mixed precisionに対応した型を用意しておき適宜castする
204
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
205
 
206
  # モデルを読み込む
207
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
208
+ # not strictly necessary, but helps on low-RAM devices
209
+ text_encoder.to("cuda")
210
+ unet.to("cuda")
 
 
 
211
  # モデルに xformers とか memory efficient attention を組み込む
212
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
213
 
 
217
  vae.requires_grad_(False)
218
  vae.eval()
219
  with torch.no_grad():
220
+ train_dataset.cache_latents(vae)
221
  vae.to("cpu")
222
  if torch.cuda.is_available():
223
  torch.cuda.empty_cache()
224
  gc.collect()
225
 
226
  # prepare network
 
 
227
  print("import network module:", args.network_module)
228
  network_module = importlib.import_module(args.network_module)
229
 
 
253
 
254
  # 学習に必要なクラスを準備する
255
  print("prepare optimizer, data loader etc.")
256
+ try:
257
+ print(f"torch_optimzier version is {optim.__version__}")
258
+ not_torch_optimizer_flag = False
259
+ except:
260
+ not_torch_optimizer_flag = True
261
+ try:
262
+ print(f"adastand version is {adastand.__version__()}")
263
+ not_adasatand_optimzier_flag = False
264
+ except:
265
+ not_adasatand_optimzier_flag = True
266
+
267
+ # 8-bit Adamを使う
268
+ if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
269
+ not_torch_optimizer_flag = False
270
+ if args.optimizer=="Adafactor":
271
+ not_adasatand_optimzier_flag = False
272
+ if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
273
+ print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
274
+ args.optimizer="AdamW"
275
+ if args.use_8bit_adam:
276
+ if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
277
+ print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
278
+ args.use_8bit_adam=False
279
+ if args.use_8bit_adam:
280
+ try:
281
+ import bitsandbytes as bnb
282
+ except ImportError:
283
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
284
+ print("use 8-bit Adam optimizer")
285
+ args.training_comment=f"{args.training_comment} use_8bit_adam=True"
286
+ if args.optimizer=="Lamb":
287
+ optimizer_class = bnb.optim.LAMB8bit
288
+ else:
289
+ args.optimizer="AdamW"
290
+ optimizer_class = bnb.optim.AdamW8bit
291
+ else:
292
+ print(f"use {args.optimizer}")
293
+ if args.optimizer=="RAdam":
294
+ optimizer_class = torch.optim.RAdam
295
+ elif args.optimizer=="AdaBound":
296
+ optimizer_class = optim.AdaBound
297
+ elif args.optimizer=="AdaBelief":
298
+ optimizer_class = optim.AdaBelief
299
+ elif args.optimizer=="AdamP":
300
+ optimizer_class = optim.AdamP
301
+ elif args.optimizer=="Adafactor":
302
+ optimizer_class = Adafactor
303
+ elif args.optimizer=="Adastand":
304
+ optimizer_class = adastand.Adastand
305
+ elif args.optimizer=="Adastand_belief":
306
+ optimizer_class = adastand.Adastand_b
307
+ elif args.optimizer=="AggMo":
308
+ optimizer_class = optim.AggMo
309
+ elif args.optimizer=="Apollo":
310
+ optimizer_class = optim.Apollo
311
+ elif args.optimizer=="Lamb":
312
+ optimizer_class = optim.Lamb
313
+ elif args.optimizer=="Ranger":
314
+ optimizer_class = optim.Ranger
315
+ elif args.optimizer=="RangerVA":
316
+ optimizer_class = optim.RangerVA
317
+ elif args.optimizer=="Yogi":
318
+ optimizer_class = optim.Yogi
319
+ elif args.optimizer=="Shampoo":
320
+ optimizer_class = optim.Shampoo
321
+ elif args.optimizer=="NovoGrad":
322
+ optimizer_class = optim.NovoGrad
323
+ elif args.optimizer=="QHAdam":
324
+ optimizer_class = optim.QHAdam
325
+ elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
326
+ optimizer_class = optim.DiffGrad
327
+ elif args.optimizer=="MADGRAD":
328
+ optimizer_class = optim.MADGRAD
329
+ else:
330
+ optimizer_class = torch.optim.AdamW
331
+
332
+ # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
333
+ #optimizerデフォ設定
334
+ if args.optimizer_arg==None:
335
+ if args.optimizer=="AdaBelief":
336
+ args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
337
+ elif args.optimizer=="DiffGrad":
338
+ args.optimizer_arg = ["eps=1e-16"]
339
+ optimizer_arg = {}
340
+ lookahed_arg = {"k": 5, "alpha": 0.5}
341
+ adafactor_scheduler_arg = {"initial_lr": 0.}
342
+ int_args = ["k","n_sma_threshold","warmup"]
343
+ str_args = ["transformer","grad_transformer"]
344
+ if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
345
+ for _opt_arg in args.optimizer_arg:
346
+ key, value = _opt_arg.split("=")
347
+ if value=="True" or value=="False":
348
+ optimizer_arg[key]=bool((value=="True"))
349
+ elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
350
+ _value = value.split(",")
351
+ optimizer_arg[key] = (float(_value[0]),float(_value[1]))
352
+ del _value
353
+ elif key in int_args:
354
+ if "Lookahead" in args.optimizer:
355
+ lookahed_arg[key] = int(value)
356
+ else:
357
+ optimizer_arg[key] = int(value)
358
+ elif key in str_args:
359
+ optimizer_arg[key] = value
360
+ else:
361
+ if key=="alpha" and "Lookahead" in args.optimizer:
362
+ lookahed_arg[key] = int(value)
363
+ elif key=="initial_lr" and args.optimizer == "Adafactor":
364
+ adafactor_scheduler_arg[key] = float(value)
365
+ else:
366
+ optimizer_arg[key] = float(value)
367
+ del _opt_arg
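# --- Editor's note: an illustrative, simplified rerun (not part of this commit) of the
# --- key=value parsing above for a typical --optimizer_arg list (the int/str and
# --- Lookahead special cases are omitted here).
_example = ["betas=0.9,0.999", "weight_decay=0.01", "relative_step=False", "eps=1e-30,1e-3"]
_parsed = {}
for _a in _example:
    _k, _v = _a.split("=")
    if _v in ("True", "False"):
        _parsed[_k] = (_v == "True")
    elif "," in _v:
        _lo, _hi = _v.split(",")
        _parsed[_k] = (float(_lo), float(_hi))
    else:
        _parsed[_k] = float(_v)
print(_parsed)
# {'betas': (0.9, 0.999), 'weight_decay': 0.01, 'relative_step': False, 'eps': (1e-30, 0.001)}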
368
+ AdafactorScheduler_Flag = False
369
+ list_of_init_lr = []
370
+ if args.optimizer=="Adafactor":
371
+ if not "relative_step" in optimizer_arg:
372
+ optimizer_arg["relative_step"] = True
373
+ if "warmup_init" in optimizer_arg:
374
+ if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
375
+ print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
376
+ optimizer_arg["relative_step"] = True
377
+ if optimizer_arg["relative_step"] == True:
378
+ AdafactorScheduler_Flag = True
379
+ list_of_init_lr = [0.,0.]
380
+ if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
381
+ if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
382
+ #if not "initial_lr" in adafactor_scheduler_arg:
383
+ # adafactor_scheduler_arg = args.learning_rate
384
+ args.learning_rate = None
385
+ args.text_encoder_lr = None
386
+ args.unet_lr = None
387
+ print(f"optimizer arg: {optimizer_arg}")
388
+ print("=-----------------------------------=")
389
+ if not AdafactorScheduler_Flag: args.split_lora_networks = False
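# --- Editor's note: a sketch (not part of this commit) of the two Adafactor modes
# --- the branch above distinguishes. The learning-rate values are only examples.
import torch
from transformers.optimization import Adafactor, AdafactorSchedule

_p = [torch.nn.Parameter(torch.zeros(4, 4))]

# relative_step=True: the lr is derived internally, so the explicit lrs are cleared
# and AdafactorSchedule only carries an initial_lr for logging purposes.
_opt_auto = Adafactor(_p, lr=None, relative_step=True, scale_parameter=True, warmup_init=True)
_sched_auto = AdafactorSchedule(_opt_auto, initial_lr=1e-4)

# relative_step=False (what --optimizer_arg relative_step=False enables):
# a fixed lr is used and a regular lr scheduler can be attached.
_opt_fixed = Adafactor(_p, lr=1e-4, relative_step=False, scale_parameter=False, warmup_init=False)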
390
  if args.split_lora_networks:
 
391
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
392
+ append_module.replace_prepare_optimizer_params(network)
393
+ trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
394
  else:
395
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
396
+ _list_of_init_lr = []
397
+ print(f"trainable_params_len: {len(trainable_params)}")
398
+ if len(_list_of_init_lr)>0:
399
+ list_of_init_lr = _list_of_init_lr
400
+ print(f"split loras network is {len(list_of_init_lr)}")
401
+ if len(list_of_init_lr) > 0:
402
+ adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
403
+
404
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
405
+
406
+ if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
407
+ optimizer = optim.Lookahead(optimizer, **lookahed_arg)
408
+ print(f"lookahed_arg: {lookahed_arg}")
409
+
 
 
 
 
 
 
 
 
 
 
410
  # dataloaderを準備する
411
  # DataLoaderのプロセス数:0はメインプロセスになる
412
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
413
  train_dataloader = torch.utils.data.DataLoader(
414
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
415
 
416
  # 学習ステップ数を計算する
417
  if args.max_train_epochs is not None:
418
+ args.max_train_steps = args.max_train_epochs * len(train_dataloader)
419
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
420
 
421
  # lr schedulerを用意する
422
+ # lr_scheduler = diffusers.optimization.get_scheduler(
423
+ if AdafactorScheduler_Flag:
424
+ print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
425
+ lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
426
+ print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
427
+ del list_of_init_lr
428
  else:
429
+ lr_scheduler = get_scheduler_fix(
430
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
431
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
432
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
433
+
434
  #追加機能の設定をコメントに追記して残す
435
+ args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
436
+ if AdafactorScheduler_Flag:
437
+ args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
 
438
  if args.min_resolution:
439
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
440
 
 
504
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
505
 
506
  # 学習する
 
507
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
508
+ print("running training / 学習開始")
509
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
510
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
511
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
512
+ print(f" num epochs / epoch数: {num_train_epochs}")
513
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
514
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
515
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
516
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
517
+
 
 
 
518
  metadata = {
519
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
520
  "ss_training_started_at": training_started_at, # unix timestamp
 
522
  "ss_learning_rate": args.learning_rate,
523
  "ss_text_encoder_lr": args.text_encoder_lr,
524
  "ss_unet_lr": args.unet_lr,
525
+ "ss_num_train_images": train_dataset.num_train_images, # includes repeating
526
+ "ss_num_reg_images": train_dataset.num_reg_images,
527
  "ss_num_batches_per_epoch": len(train_dataloader),
528
  "ss_num_epochs": num_train_epochs,
529
+ "ss_batch_size_per_device": args.train_batch_size,
530
+ "ss_total_batch_size": total_batch_size,
531
  "ss_gradient_checkpointing": args.gradient_checkpointing,
532
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
533
  "ss_max_train_steps": args.max_train_steps,
 
539
  "ss_mixed_precision": args.mixed_precision,
540
  "ss_full_fp16": bool(args.full_fp16),
541
  "ss_v2": bool(args.v2),
542
+ "ss_resolution": args.resolution,
543
  "ss_clip_skip": args.clip_skip,
544
  "ss_max_token_length": args.max_token_length,
545
+ "ss_color_aug": bool(args.color_aug),
546
+ "ss_flip_aug": bool(args.flip_aug),
547
+ "ss_random_crop": bool(args.random_crop),
548
+ "ss_shuffle_caption": bool(args.shuffle_caption),
549
  "ss_cache_latents": bool(args.cache_latents),
550
+ "ss_enable_bucket": bool(train_dataset.enable_bucket),
551
+ "ss_min_bucket_reso": train_dataset.min_bucket_reso,
552
+ "ss_max_bucket_reso": train_dataset.max_bucket_reso,
553
  "ss_seed": args.seed,
554
+ "ss_keep_tokens": args.keep_tokens,
555
  "ss_noise_offset": args.noise_offset,
556
+ "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
557
+ "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
558
+ "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
559
+ "ss_bucket_info": json.dumps(train_dataset.bucket_info),
560
  "ss_training_comment": args.training_comment, # will not be updated after training
561
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
 
 
 
 
 
 
 
562
  }
563
 
564
+ # uncomment if another network is added
565
  # for key, value in net_kwargs.items():
566
  # metadata["ss_arg_" + key] = value
567
 
 
568
  if args.pretrained_model_name_or_path is not None:
569
  sd_model_name = args.pretrained_model_name_or_path
570
  if os.path.exists(sd_model_name):
 
583
 
584
  metadata = {k: str(v) for k, v in metadata.items()}
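# --- Editor's note: a sketch (not part of this commit) of why every value is
# --- stringified above: the safetensors header only stores str -> str pairs.
# Reading the metadata back from a saved LoRA (the file name is hypothetical):
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt") as f:
    meta = f.metadata() or {}
print(meta.get("ss_num_epochs"))
print(json.loads(meta.get("ss_tag_frequency", "{}")))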
585
 
 
 
 
 
 
 
 
586
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
587
  global_step = 0
588
 
 
595
  loss_list = []
596
  loss_total = 0.0
597
  for epoch in range(num_train_epochs):
598
+ print(f"epoch {epoch+1}/{num_train_epochs}")
599
+ train_dataset.set_current_epoch(epoch + 1)
 
600
 
601
  metadata["ss_epoch"] = str(epoch+1)
602
 
 
633
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
634
 
635
  # Predict the noise residual
636
+ with autocast():
637
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
638
 
639
  if args.v_parameterization:
 
651
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
652
 
653
  accelerator.backward(loss)
654
+ if accelerator.sync_gradients:
655
  params_to_clip = network.get_trainable_params()
656
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
657
 
658
  optimizer.step()
659
+ lr_scheduler.step()
 
660
  optimizer.zero_grad(set_to_none=True)
661
 
662
  # Checks if the accelerator has performed an optimization step behind the scenes
 
664
  progress_bar.update(1)
665
  global_step += 1
666
 
 
 
667
  current_loss = loss.detach().item()
668
  if epoch == 0:
669
  loss_list.append(current_loss)
 
676
  progress_bar.set_postfix(**logs)
677
 
678
  if args.logging_dir is not None:
679
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
680
  accelerator.log(logs, step=global_step)
681
 
682
  if global_step >= args.max_train_steps:
 
694
  def save_func():
695
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
696
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
697
  print(f"saving checkpoint: {ckpt_file}")
698
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
699
 
700
  def remove_old_func(old_epoch_no):
701
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
704
  print(f"removing old checkpoint: {old_ckpt_file}")
705
  os.remove(old_ckpt_file)
706
 
707
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
708
+ if saving and args.save_state:
709
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
710
 
711
  # end of epoch
712
 
713
  metadata["ss_epoch"] = str(num_train_epochs)
 
714
 
715
+ is_main_process = accelerator.is_main_process
716
  if is_main_process:
717
  network = unwrap_model(network)
718
 
 
731
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
732
 
733
  print(f"save trained model to {ckpt_file}")
734
+ network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
735
  print("model saved.")
736
 
737
 
 
741
  train_util.add_sd_models_arguments(parser)
742
  train_util.add_dataset_arguments(parser, True, True, True)
743
  train_util.add_training_arguments(parser, True)
 
 
744
 
745
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
746
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
748
 
749
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
750
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
751
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
752
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
753
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
754
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
755
 
756
  parser.add_argument("--network_weights", type=str, default=None,
757
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
771
  #Optimizer変更関連のオプション追加
772
  append_module.add_append_arguments(parser)
773
  args = append_module.get_config(parser)
 
 
 
 
 
 
 
 
 
 
 
 
774
 
775
  if args.resolution==args.min_resolution:
776
  args.min_resolution=None
777
 
778
  train(args)
 
779
 
780
+ # save the current args once training has finished
781
+ # import yaml
782
+ # import datetime
783
+ # _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
784
+ # if args.output_name==None:
785
+ # config_name = f"train_network_config_{_t}.yaml"
786
+ # else:
787
+ # config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
788
+ # print(f"{config_name} に設定を書き出し中...")
789
+ # with open(config_name, mode="w") as f:
790
+ # yaml.dump(args.__dict__, f, indent=4)
791
+ # print("done!")
792
 
793
  '''
794
optimizer settings memo
 
 
795
(notes toward making these configurable via optimizer_arg)
796
 
797
AdamW's default weight_decay is 1e-2
 
821
rumored to be the strongest optimizer for transformer-based T5 training
822
huggingface sample parameters
823
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
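# --- Editor's note: the huggingface sample parameters quoted above, written out as
# --- an explicit Adafactor construction (not part of this commit; the fixed lr of
# --- 1e-3 is only an example, required because relative_step=False).
import torch
from transformers.optimization import Adafactor

_p = [torch.nn.Parameter(torch.zeros(4, 4))]
_opt = Adafactor(_p, lr=1e-3, eps=(1e-30, 1e-3), clip_threshold=1.0,
                 decay_rate=-0.8, relative_step=False,
                 scale_parameter=False, warmup_init=False)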
 
824
 
825
  AggMo
826
 
train_textual_inversion.py CHANGED
@@ -11,11 +11,7 @@ import diffusers
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
- import library.config_util as config_util
15
- from library.config_util import (
16
- ConfigSanitizer,
17
- BlueprintGenerator,
18
- )
19
 
20
  imagenet_templates_small = [
21
  "a photo of a {}",
@@ -83,6 +79,7 @@ def train(args):
83
  train_util.prepare_dataset_args(args, True)
84
 
85
  cache_latents = args.cache_latents
 
86
 
87
  if args.seed is not None:
88
  set_seed(args.seed)
@@ -142,35 +139,21 @@ def train(args):
142
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
143
 
144
  # データセットを準備する
145
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
146
- if args.dataset_config is not None:
147
- print(f"Load dataset config from {args.dataset_config}")
148
- user_config = config_util.load_user_config(args.dataset_config)
149
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
150
- if any(getattr(args, attr) is not None for attr in ignored):
151
- print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
152
  else:
153
- use_dreambooth_method = args.in_json is None
154
- if use_dreambooth_method:
155
- print("Use DreamBooth method.")
156
- user_config = {
157
- "datasets": [{
158
- "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
159
- }]
160
- }
161
- else:
162
- print("Train with captions.")
163
- user_config = {
164
- "datasets": [{
165
- "subsets": [{
166
- "image_dir": args.train_data_dir,
167
- "metadata_file": args.in_json,
168
- }]
169
- }]
170
- }
171
-
172
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
173
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
174
 
175
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
176
  if use_template:
@@ -180,30 +163,20 @@ def train(args):
180
  captions = []
181
  for tmpl in templates:
182
  captions.append(tmpl.format(replace_to))
183
- train_dataset_group.add_replacement("", captions)
 
 
 
184
 
185
- if args.num_vectors_per_token > 1:
186
- prompt_replacement = (args.token_string, replace_to)
187
- else:
188
- prompt_replacement = None
189
- else:
190
- if args.num_vectors_per_token > 1:
191
- replace_to = " ".join(token_strings)
192
- train_dataset_group.add_replacement(args.token_string, replace_to)
193
- prompt_replacement = (args.token_string, replace_to)
194
- else:
195
- prompt_replacement = None
196
 
197
  if args.debug_dataset:
198
- train_util.debug_dataset(train_dataset_group, show_input_ids=True)
199
  return
200
- if len(train_dataset_group) == 0:
201
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
202
  return
203
 
204
- if cache_latents:
205
- assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
206
-
207
  # モデルに xformers とか memory efficient attention を組み込む
208
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
209
 
@@ -213,7 +186,7 @@ def train(args):
213
  vae.requires_grad_(False)
214
  vae.eval()
215
  with torch.no_grad():
216
- train_dataset_group.cache_latents(vae)
217
  vae.to("cpu")
218
  if torch.cuda.is_available():
219
  torch.cuda.empty_cache()
@@ -225,14 +198,35 @@ def train(args):
225
 
226
  # 学習に必要なクラスを準備する
227
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  trainable_params = text_encoder.get_input_embeddings().parameters()
229
- _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
230
 
231
  # dataloaderを準備する
232
  # DataLoaderのプロセス数:0はメインプロセスになる
233
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
234
  train_dataloader = torch.utils.data.DataLoader(
235
- train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
236
 
237
  # 学習ステップ数を計算する
238
  if args.max_train_epochs is not None:
@@ -240,9 +234,8 @@ def train(args):
240
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
241
 
242
  # lr schedulerを用意する
243
- lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
244
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
245
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
246
 
247
  # acceleratorがなんかよろしくやってくれるらしい
248
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -290,8 +283,8 @@ def train(args):
290
  # 学習する
291
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
292
  print("running training / 学習開始")
293
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
294
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
295
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
296
  print(f" num epochs / epoch数: {num_train_epochs}")
297
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -310,11 +303,12 @@ def train(args):
310
 
311
  for epoch in range(num_train_epochs):
312
  print(f"epoch {epoch+1}/{num_train_epochs}")
313
- train_dataset_group.set_current_epoch(epoch + 1)
314
 
315
  text_encoder.train()
316
 
317
  loss_total = 0
 
318
  for step, batch in enumerate(train_dataloader):
319
  with accelerator.accumulate(text_encoder):
320
  with torch.no_grad():
@@ -363,9 +357,9 @@ def train(args):
363
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
364
 
365
  accelerator.backward(loss)
366
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
367
  params_to_clip = text_encoder.get_input_embeddings().parameters()
368
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
369
 
370
  optimizer.step()
371
  lr_scheduler.step()
@@ -380,14 +374,9 @@ def train(args):
380
  progress_bar.update(1)
381
  global_step += 1
382
 
383
- train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
384
- vae, tokenizer, text_encoder, unet, prompt_replacement)
385
-
386
  current_loss = loss.detach().item()
387
  if args.logging_dir is not None:
388
- logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
389
- if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
390
- logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
391
  accelerator.log(logs, step=global_step)
392
 
393
  loss_total += current_loss
@@ -405,6 +394,8 @@ def train(args):
405
  accelerator.wait_for_everyone()
406
 
407
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
 
 
408
 
409
  if args.save_every_n_epochs is not None:
410
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -426,9 +417,6 @@ def train(args):
426
  if saving and args.save_state:
427
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
428
 
429
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
430
- vae, tokenizer, text_encoder, unet, prompt_replacement)
431
-
432
  # end of epoch
433
 
434
  is_main_process = accelerator.is_main_process
@@ -503,8 +491,6 @@ if __name__ == '__main__':
503
  train_util.add_sd_models_arguments(parser)
504
  train_util.add_dataset_arguments(parser, True, True, False)
505
  train_util.add_training_arguments(parser, True)
506
- train_util.add_optimizer_arguments(parser)
507
- config_util.add_config_arguments(parser)
508
 
509
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
510
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
 
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
+ from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
15
 
16
  imagenet_templates_small = [
17
  "a photo of a {}",
 
79
  train_util.prepare_dataset_args(args, True)
80
 
81
  cache_latents = args.cache_latents
82
+ use_dreambooth_method = args.in_json is None
83
 
84
  if args.seed is not None:
85
  set_seed(args.seed)
 
139
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
140
 
141
  # データセットを準備する
142
+ if use_dreambooth_method:
143
+ print("Use DreamBooth method.")
144
+ train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
145
+ tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
146
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
147
+ args.bucket_reso_steps, args.bucket_no_upscale,
148
+ args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
149
  else:
150
+ print("Train with captions.")
151
+ train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
152
+ tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
153
+ args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
154
+ args.bucket_reso_steps, args.bucket_no_upscale,
155
+ args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
156
+ args.dataset_repeats, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
159
  if use_template:
 
163
  captions = []
164
  for tmpl in templates:
165
  captions.append(tmpl.format(replace_to))
166
+ train_dataset.add_replacement("", captions)
167
+ elif args.num_vectors_per_token > 1:
168
+ replace_to = " ".join(token_strings)
169
+ train_dataset.add_replacement(args.token_string, replace_to)
170
 
171
+ train_dataset.make_buckets()
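# --- Editor's note: a sketch (not part of this commit) of the caption rewrite above,
# --- assuming token_strings is built as "tokenstring tokenstring1 ..." in the way
# --- the comment at the top of this block describes.
_token_string = "sks"
_num_vectors_per_token = 3
_token_strings = [_token_string] + [f"{_token_string}{i+1}" for i in range(_num_vectors_per_token - 1)]
_replace_to = " ".join(_token_strings)
print(_replace_to)  # "sks sks1 sks2" -> substituted for "sks" in every caption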
 
 
 
 
 
 
 
 
 
 
172
 
173
  if args.debug_dataset:
174
+ train_util.debug_dataset(train_dataset, show_input_ids=True)
175
  return
176
+ if len(train_dataset) == 0:
177
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
178
  return
179
 
 
 
 
180
  # モデルに xformers とか memory efficient attention を組み込む
181
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
182
 
 
186
  vae.requires_grad_(False)
187
  vae.eval()
188
  with torch.no_grad():
189
+ train_dataset.cache_latents(vae)
190
  vae.to("cpu")
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
 
198
 
199
  # 学習に必要なクラスを準備する
200
  print("prepare optimizer, data loader etc.")
201
+
202
+ # 8-bit Adamを使う
203
+ if args.use_8bit_adam:
204
+ try:
205
+ import bitsandbytes as bnb
206
+ except ImportError:
207
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
208
+ print("use 8-bit Adam optimizer")
209
+ optimizer_class = bnb.optim.AdamW8bit
210
+ elif args.use_lion_optimizer:
211
+ try:
212
+ import lion_pytorch
213
+ except ImportError:
214
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
215
+ print("use Lion optimizer")
216
+ optimizer_class = lion_pytorch.Lion
217
+ else:
218
+ optimizer_class = torch.optim.AdamW
219
+
220
  trainable_params = text_encoder.get_input_embeddings().parameters()
221
+
222
+ # betas and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
223
+ optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
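Textual Inversion optimizes only the text encoder's token-embedding matrix, which is why trainable_params above is just get_input_embeddings().parameters(). A hedged sanity check, not part of the script, that could be run after this line:

```python
emb = text_encoder.get_input_embeddings()        # an nn.Embedding
n_params = sum(p.numel() for p in emb.parameters())
print(f"optimizing {n_params} weights ({emb.num_embeddings} tokens x {emb.embedding_dim} dims)")
```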
224
 
225
  # prepare the dataloader
226
  # number of DataLoader worker processes: 0 means the main process
227
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but at most the specified number
228
  train_dataloader = torch.utils.data.DataLoader(
229
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
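The batch_size=1 above is deliberate: these dataset classes group images into bucket-sized batches themselves, so each DataLoader item is already a complete batch. The collate_fn is assumed to do little more than unwrap it, roughly:

```python
def collate_fn(examples):
    # each "example" is a pre-built batch dict produced by the dataset's bucketing
    return examples[0]
```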
230
 
231
  # calculate the number of training steps
232
  if args.max_train_epochs is not None:
 
234
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
235
 
236
  # prepare the lr scheduler
237
+ lr_scheduler = diffusers.optimization.get_scheduler(
238
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
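diffusers.optimization.get_scheduler resolves a schedule name such as "constant", "cosine" or "constant_with_warmup" to a torch LR scheduler; num_training_steps is presumably multiplied by gradient_accumulation_steps here because the scheduler is stepped once per micro-batch. A self-contained usage sketch with made-up numbers:

```python
import torch
from diffusers.optimization import get_scheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(4))], lr=1e-3)
sched = get_scheduler("cosine", opt, num_warmup_steps=100, num_training_steps=1000)
```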
 
239
 
240
  # accelerator apparently takes care of the rest for us
241
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 
283
  # start training
284
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
285
  print("running training / 学習開始")
286
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
287
+ print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
288
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
289
  print(f" num epochs / epoch数: {num_train_epochs}")
290
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
303
 
304
  for epoch in range(num_train_epochs):
305
  print(f"epoch {epoch+1}/{num_train_epochs}")
306
+ train_dataset.set_current_epoch(epoch + 1)
307
 
308
  text_encoder.train()
309
 
310
  loss_total = 0
311
+ bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
312
  for step, batch in enumerate(train_dataloader):
313
  with accelerator.accumulate(text_encoder):
314
  with torch.no_grad():
 
357
  loss = loss.mean() # already a mean, so no need to divide by batch_size
358
 
359
  accelerator.backward(loss)
360
+ if accelerator.sync_gradients:
361
  params_to_clip = text_encoder.get_input_embeddings().parameters()
362
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
363
 
364
  optimizer.step()
365
  lr_scheduler.step()
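The update above follows Accelerate's gradient-accumulation idiom: gradients are clipped only on iterations where they have actually been synchronized. A self-contained toy example of the same pattern (the model, numbers and accumulation setting are made up, not taken from this script):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(4):
    with accelerator.accumulate(model):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        accelerator.backward(loss)
        if accelerator.sync_gradients:  # True only at accumulation boundaries
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```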
 
374
  progress_bar.update(1)
375
  global_step += 1
376
 
377
  current_loss = loss.detach().item()
378
  if args.logging_dir is not None:
379
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
380
  accelerator.log(logs, step=global_step)
381
 
382
  loss_total += current_loss
 
394
  accelerator.wait_for_everyone()
395
 
396
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
397
+ # d = updated_embs - bef_epo_embs
398
+ # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
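The commented-out lines above hint at a simple per-epoch diagnostic: how far the learned token embeddings drifted over the epoch. A hedged sketch of what they would print, using the two snapshots taken in this loop:

```python
d = (updated_embs - bef_epo_embs).abs()
print(f"token embedding drift this epoch: mean={d.mean().item():.3e} max={d.max().item():.3e}")
```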
399
 
400
  if args.save_every_n_epochs is not None:
401
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
 
417
  if saving and args.save_state:
418
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
419
 
420
  # end of epoch
421
 
422
  is_main_process = accelerator.is_main_process
 
491
  train_util.add_sd_models_arguments(parser)
492
  train_util.add_dataset_arguments(parser, True, True, False)
493
  train_util.add_training_arguments(parser, True)
494
 
495
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
496
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")