abc committed on
Commit 350076d · 1 Parent(s): b71766b

Upload 35 files

append_module.py CHANGED
@@ -2,7 +2,19 @@ import argparse
2
  import json
3
  import shutil
4
  import time
5
- from typing import Dict, List, NamedTuple, Tuple
6
  from accelerate import Accelerator
7
  from torch.autograd.function import Function
8
  import glob
@@ -28,6 +40,7 @@ import safetensors.torch
28
 
29
  import library.model_util as model_util
30
  import library.train_util as train_util
 
31
 
32
  #============================================================================================================
33
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
@@ -115,6 +128,124 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
115
  return area_size_resos_list, area_size_list
116
 
117
  #============================================================================================================
118
  #train_util 内より
119
  #============================================================================================================
120
  class BucketManager_append(train_util.BucketManager):
@@ -179,7 +310,7 @@ class BucketManager_append(train_util.BucketManager):
179
  bucket_size_id_list.append(bucket_size_id + i + 1)
180
  _min_error = 1000.
181
  _min_id = bucket_size_id
182
- for now_size_id in bucket_size_id:
183
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
184
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
185
  ar_error = np.abs(ar_errors).min()
@@ -253,13 +384,13 @@ class BucketManager_append(train_util.BucketManager):
253
  return reso, resized_size, ar_error
254
 
255
  class DreamBoothDataset(train_util.DreamBoothDataset):
256
- def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset, min_resolution=None, area_step=None) -> None:
257
  print("use append DreamBoothDataset")
258
  self.min_resolution = min_resolution
259
  self.area_step = area_step
260
- super().__init__(batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
261
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight,
262
- flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
263
  def make_buckets(self):
264
  '''
265
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
@@ -353,11 +484,10 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
353
  self._length = len(self.buckets_indices)
354
 
355
  class FineTuningDataset(train_util.FineTuningDataset):
356
- def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
357
  train_util.glob_images = glob_images
358
- super().__init__( json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
359
- resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range,
360
- random_crop, dataset_repeats, debug_dataset)
361
 
362
  def glob_images(directory, base="*", npz_flag=True):
363
  img_paths = []
@@ -373,13 +503,26 @@ def glob_images(directory, base="*", npz_flag=True):
373
  img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
374
  return img_paths
375
 
376
  #============================================================================================================
377
  #networks.lora
378
  #============================================================================================================
379
  from networks.lora import LoRANetwork
380
  def replace_prepare_optimizer_params(networks):
381
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr, scheduler_lr=None, loranames=None):
382
-
383
  def enumerate_params(loras, lora_name=None):
384
  params = []
385
  for lora in loras:
@@ -393,6 +536,7 @@ def replace_prepare_optimizer_params(networks):
393
  self.requires_grad_(True)
394
  all_params = []
395
  ret_scheduler_lr = []
 
396
 
397
  if loranames is not None:
398
  textencoder_names = [None]
@@ -405,22 +549,60 @@ def replace_prepare_optimizer_params(networks):
405
  if self.text_encoder_loras:
406
  for textencoder_name in textencoder_names:
407
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
 
408
  if text_encoder_lr is not None:
409
  param_data['lr'] = text_encoder_lr
410
- if scheduler_lr is not None:
411
- ret_scheduler_lr.append(scheduler_lr[0])
412
  all_params.append(param_data)
413
 
414
  if self.unet_loras:
415
  for unet_name in unet_names:
416
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
 
417
  if unet_lr is not None:
418
  param_data['lr'] = unet_lr
419
- if scheduler_lr is not None:
420
- ret_scheduler_lr.append(scheduler_lr[1])
421
  all_params.append(param_data)
422
 
423
- return all_params, ret_scheduler_lr
424
 
425
  LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
426
 
@@ -429,14 +611,98 @@ def replace_prepare_optimizer_params(networks):
429
  #============================================================================================================
430
  def add_append_arguments(parser: argparse.ArgumentParser):
431
  # for train_network_opt.py
432
- parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
433
- parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
 
 
434
  parser.add_argument("--split_lora_networks", action="store_true")
435
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
 
 
436
  parser.add_argument("--min_resolution", type=str, default=None)
437
  parser.add_argument("--area_step", type=int, default=1)
438
  parser.add_argument("--config", type=str, default=None)
439
440
  def create_split_names(split_flag, split_level):
441
  split_names = None
442
  if split_flag:
@@ -446,14 +712,23 @@ def create_split_names(split_flag, split_level):
446
  if split_level==1:
447
  unet_names.append(f"lora_unet_down_blocks_")
448
  unet_names.append(f"lora_unet_up_blocks_")
449
- elif split_level==2 or split_level==0:
450
- if split_level==2:
451
  text_encoder_names = []
452
  for i in range(12):
453
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
454
  for i in range(3):
455
- unet_names.append(f"lora_unet_down_blocks_{i}")
456
- unet_names.append(f"lora_unet_up_blocks_{i+1}")
 
 
457
  split_names["text_encoder"] = text_encoder_names
458
  split_names["unet"] = unet_names
459
  return split_names
@@ -465,7 +740,7 @@ def get_config(parser):
465
  import datetime
466
  if os.path.splitext(args.config)[-1] == ".yaml":
467
  args.config = os.path.splitext(args.config)[0]
468
- config_path = f"./{args.config}.yaml"
469
  if os.path.exists(config_path):
470
  print(f"{config_path} から設定を読み込み中...")
471
  margs, rest = parser.parse_known_args()
@@ -486,19 +761,41 @@ def get_config(parser):
486
  args_type_dic[key] = act.type
487
  #データタイプの確認とargsにkeyの内容を代入していく
488
  for key, v in configs.items():
489
- if key in args_dic:
490
- if args_dic[key] is not None:
491
- new_type = type(args_dic[key])
492
- if (not type(v) == new_type) and (not new_type==list):
493
- v = new_type(v)
494
- else:
495
- if v is not None:
496
  if not type(v) == args_type_dic[key]:
497
  v = args_type_dic[key](v)
498
- args_dic[key] = v
499
  #最後にデフォから指定が変わってるものを変更する
500
  for key, v in change_def_dic.items():
501
  args_dic[key] = v
502
  else:
503
  print(f"{config_path} が見つかりませんでした")
504
  return args
2
  import json
3
  import shutil
4
  import time
5
+ from typing import (
6
+ Dict,
7
+ List,
8
+ NamedTuple,
9
+ Optional,
10
+ Sequence,
11
+ Tuple,
12
+ Union,
13
+ )
14
+ from dataclasses import (
15
+ asdict,
16
+ dataclass,
17
+ )
18
  from accelerate import Accelerator
19
  from torch.autograd.function import Function
20
  import glob
 
40
 
41
  import library.model_util as model_util
42
  import library.train_util as train_util
43
+ import library.config_util as config_util
44
 
45
  #============================================================================================================
46
  #AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
 
128
  return area_size_resos_list, area_size_list
129
 
130
  #============================================================================================================
131
+ #config_util 内より
132
+ #============================================================================================================
133
+ @dataclass
134
+ class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
135
+ min_resolution: Optional[Tuple[int, int]] = None
136
+ area_step : int = 2
137
+
138
+ class ConfigSanitizer(config_util.ConfigSanitizer):
139
+ #@config_util.curry
140
+ @staticmethod
141
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
142
+ config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
143
+ return tuple(value)
144
+
145
+ #@config_util.curry
146
+ @staticmethod
147
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
148
+ config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
149
+ try:
150
+ config_util.Schema(klass)(value)
151
+ return (value, value)
152
+ except:
153
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
154
+ # datasets schema
155
+ DATASET_ASCENDABLE_SCHEMA = {
156
+ "batch_size": int,
157
+ "bucket_no_upscale": bool,
158
+ "bucket_reso_steps": int,
159
+ "enable_bucket": bool,
160
+ "max_bucket_reso": int,
161
+ "min_bucket_reso": int,
162
+ "resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
163
+ "min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
164
+ "area_step": int,
165
+ }
166
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
167
+ super().__init__(support_dreambooth, support_finetuning, support_dropout)
168
+ def _check(self):
169
+ print(self.db_dataset_schema)
170
+
171
+ class BlueprintGenerator(config_util.BlueprintGenerator):
172
+ def __init__(self, sanitizer: ConfigSanitizer):
173
+ config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
174
+ super().__init__(sanitizer)
175
+
176
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
177
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
178
+
179
+ for dataset_blueprint in dataset_group_blueprint.datasets:
180
+ if dataset_blueprint.is_dreambooth:
181
+ subset_klass = train_util.DreamBoothSubset
182
+ dataset_klass = DreamBoothDataset
183
+ else:
184
+ subset_klass = train_util.FineTuningSubset
185
+ dataset_klass = FineTuningDataset
186
+
187
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
188
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
189
+ datasets.append(dataset)
190
+
191
+ # print info
192
+ info = ""
193
+ for i, dataset in enumerate(datasets):
194
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
195
+ info += config_util.dedent(f"""\
196
+ [Dataset {i}]
197
+ batch_size: {dataset.batch_size}
198
+ resolution: {(dataset.width, dataset.height)}
199
+ enable_bucket: {dataset.enable_bucket}
200
+ """)
201
+
202
+ if dataset.enable_bucket:
203
+ info += config_util.indent(config_util.dedent(f"""\
204
+ min_bucket_reso: {dataset.min_bucket_reso}
205
+ max_bucket_reso: {dataset.max_bucket_reso}
206
+ bucket_reso_steps: {dataset.bucket_reso_steps}
207
+ bucket_no_upscale: {dataset.bucket_no_upscale}
208
+ \n"""), " ")
209
+ else:
210
+ info += "\n"
211
+
212
+ for j, subset in enumerate(dataset.subsets):
213
+ info += config_util.indent(config_util.dedent(f"""\
214
+ [Subset {j} of Dataset {i}]
215
+ image_dir: "{subset.image_dir}"
216
+ image_count: {subset.img_count}
217
+ num_repeats: {subset.num_repeats}
218
+ shuffle_caption: {subset.shuffle_caption}
219
+ keep_tokens: {subset.keep_tokens}
220
+ caption_dropout_rate: {subset.caption_dropout_rate}
221
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
222
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
223
+ color_aug: {subset.color_aug}
224
+ flip_aug: {subset.flip_aug}
225
+ face_crop_aug_range: {subset.face_crop_aug_range}
226
+ random_crop: {subset.random_crop}
227
+ """), " ")
228
+
229
+ if is_dreambooth:
230
+ info += config_util.indent(config_util.dedent(f"""\
231
+ is_reg: {subset.is_reg}
232
+ class_tokens: {subset.class_tokens}
233
+ caption_extension: {subset.caption_extension}
234
+ \n"""), " ")
235
+ else:
236
+ info += config_util.indent(config_util.dedent(f"""\
237
+ metadata_file: {subset.metadata_file}
238
+ \n"""), " ")
239
+
240
+ print(info)
241
+
242
+ # make buckets first because it determines the length of dataset
243
+ for i, dataset in enumerate(datasets):
244
+ print(f"[Dataset {i}]")
245
+ dataset.make_buckets()
246
+
247
+ return train_util.DatasetGroup(datasets)
248
+ #============================================================================================================
249
  #train_util 内より
250
  #============================================================================================================
251
  class BucketManager_append(train_util.BucketManager):
 
310
  bucket_size_id_list.append(bucket_size_id + i + 1)
311
  _min_error = 1000.
312
  _min_id = bucket_size_id
313
+ for now_size_id in bucket_size_id_list:
314
  self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
315
  ar_errors = self.predefined_aspect_ratios - aspect_ratio
316
  ar_error = np.abs(ar_errors).min()
 
384
  return reso, resized_size, ar_error
385
 
386
  class DreamBoothDataset(train_util.DreamBoothDataset):
387
+ def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
388
  print("use append DreamBoothDataset")
389
  self.min_resolution = min_resolution
390
  self.area_step = area_step
391
+ super().__init__(subsets, batch_size, tokenizer, max_token_length,
392
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
393
+ prior_loss_weight, debug_dataset)
394
  def make_buckets(self):
395
  '''
396
  bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
 
484
  self._length = len(self.buckets_indices)
485
 
486
  class FineTuningDataset(train_util.FineTuningDataset):
487
+ def __init__(self, subsets: Sequence[train_util.FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
488
  train_util.glob_images = glob_images
489
+ super().__init__(subsets, batch_size, tokenizer, max_token_length,
490
+ resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, debug_dataset)
 
491
 
492
  def glob_images(directory, base="*", npz_flag=True):
493
  img_paths = []
 
503
  img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
504
  return img_paths
505
 
506
+ import transformers
507
+ from torch.optim import Optimizer
508
+ from diffusers.optimization import SchedulerType
509
+ from typing import Union
510
+ def get_scheduler_Adafactor(
511
+ name: Union[str, SchedulerType],
512
+ optimizer: Optimizer,
513
+ scheduler_arg: Dict
514
+ ):
515
+ if name.startswith("adafactor"):
516
+ assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
517
+ print(scheduler_arg)
518
+ return AdafactorSchedule_append(optimizer, **scheduler_arg)
519
  #============================================================================================================
520
  #networks.lora
521
  #============================================================================================================
522
  from networks.lora import LoRANetwork
523
  def replace_prepare_optimizer_params(networks):
524
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
525
+
526
  def enumerate_params(loras, lora_name=None):
527
  params = []
528
  for lora in loras:
 
536
  self.requires_grad_(True)
537
  all_params = []
538
  ret_scheduler_lr = []
539
+ used_names = []
540
 
541
  if loranames is not None:
542
  textencoder_names = [None]
 
549
  if self.text_encoder_loras:
550
  for textencoder_name in textencoder_names:
551
  param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
552
+ used_names.append(textencoder_name)
553
  if text_encoder_lr is not None:
554
  param_data['lr'] = text_encoder_lr
555
+ if lr_dic is not None:
556
+ if textencoder_name in lr_dic:
557
+ param_data['lr'] = lr_dic[textencoder_name]
558
+ print(f"{textencoder_name} lr: {param_data['lr']}")
559
+
560
+ if block_args_dic is not None:
561
+ if "lora_te_" in block_args_dic:
562
+ for pname, value in block_args_dic["lora_te_"].items():
563
+ param_data[pname] = value
564
+ if textencoder_name in block_args_dic:
565
+ for pname, value in block_args_dic[textencoder_name].items():
566
+ param_data[pname] = value
567
+
568
+ if text_encoder_lr is not None:
569
+ ret_scheduler_lr.append(text_encoder_lr)
570
+ else:
571
+ ret_scheduler_lr.append(0.)
572
+ if lr_dic is not None:
573
+ if textencoder_name in lr_dic:
574
+ ret_scheduler_lr[-1] = lr_dic[textencoder_name]
575
  all_params.append(param_data)
576
 
577
  if self.unet_loras:
578
  for unet_name in unet_names:
579
  param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
580
+ used_names.append(unet_name)
581
  if unet_lr is not None:
582
  param_data['lr'] = unet_lr
583
+ if lr_dic is not None:
584
+ if unet_name in lr_dic:
585
+ param_data['lr'] = lr_dic[unet_name]
586
+ print(f"{unet_name} lr: {param_data['lr']}")
587
+
588
+ if block_args_dic is not None:
589
+ if "lora_unet_" in block_args_dic:
590
+ for pname, value in block_args_dic["lora_unet_"].items():
591
+ param_data[pname] = value
592
+ if unet_name in block_args_dic:
593
+ for pname, value in block_args_dic[unet_name].items():
594
+ param_data[pname] = value
595
+
596
+ if unet_lr is not None:
597
+ ret_scheduler_lr.append(unet_lr)
598
+ else:
599
+ ret_scheduler_lr.append(0.)
600
+ if lr_dic is not None:
601
+ if unet_name in lr_dic:
602
+ ret_scheduler_lr[-1] = lr_dic[unet_name]
603
  all_params.append(param_data)
604
 
605
+ return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
606
 
607
  LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
608
 
 
611
  #============================================================================================================
612
  def add_append_arguments(parser: argparse.ArgumentParser):
613
  # for train_network_opt.py
614
+ #parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
615
+ #parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
616
+ parser.add_argument("--use_lookahead", action="store_true")
617
+ parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
618
  parser.add_argument("--split_lora_networks", action="store_true")
619
  parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
620
+ parser.add_argument("--blocks_lr_setting", type=str, default=None)
621
+ parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
622
  parser.add_argument("--min_resolution", type=str, default=None)
623
  parser.add_argument("--area_step", type=int, default=1)
624
  parser.add_argument("--config", type=str, default=None)
625
 
626
+ def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
627
+ ex_block_weight_dic = {
628
+ "BASE": "te",
629
+ "IN01": "down_0_at_0", "IN02": "down_0_at_1",
630
+ "IN04": "down_1_at_0", "IN05": "down_1_at_1",
631
+ "IN07": "down_2_at_0", "IN08": "down_2_at_1",
632
+ "MID": "mid",
633
+ "OUT03": "up_1_at_0", "OUT04": "up_1_at_1", "OUT05": "up_1_at_2",
634
+ "OUT06": "up_2_at_0", "OUT07": "up_2_at_1", "OUT08": "up_2_at_2",
635
+ "OUT09": "up_3_at_0", "OUT10": "up_3_at_1", "OUT11": "up_3_at_2",
636
+ }
637
+
638
+ blocks_name_dic = { "te": "lora_te_",
639
+ "unet": "lora_unet_",
640
+ "mid": "lora_unet_mid_block",
641
+ "down": "lora_unet_down_blocks_",
642
+ "up": "lora_unet_up_blocks_"}
643
+ for i in range(12):
644
+ blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
645
+ for i in range(3):
646
+ blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
647
+ blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
648
+ for i in range(3):
649
+ for j in range(2):
650
+ blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
651
+ for j in range(3):
652
+ blocks_name_dic[f"up_{i+1}_at_{j}"] = f"lora_unet_up_blocks_{i+1}_attentions_{j}_"
653
+
654
+ lr_dic = {}
655
+ if lr_setting_str==None or lr_setting_str=="":
656
+ pass
657
+ else:
658
+ lr_settings = lr_setting_str.replace(" ", "").split(",")
659
+ for lr_setting in lr_settings:
660
+ key, value = lr_setting.split("=")
661
+ if key in ex_block_weight_dic:
662
+ key = ex_block_weight_dic[key]
663
+ if key in blocks_name_dic:
664
+ new_key = blocks_name_dic[key]
665
+ lr_dic[new_key] = float(value)
666
+ if len(lr_dic)==0:
667
+ lr_dic = None
668
+
669
+ args_dic = {}
670
+ if (block_optim_args is None):
671
+ block_optim_args = []
672
+ if (len(block_optim_args)>0):
673
+ for my_arg in block_optim_args:
674
+ my_arg = my_arg.replace(" ", "")
675
+ splits = my_arg.split(":")
676
+ b_name = splits[0]
677
+ if b_name in ex_block_weight_dic:
678
+ b_name = ex_block_weight_dic[b_name]
679
+ new_b_name = blocks_name_dic[b_name]
680
+ key, _value = splits[1].split("=")
681
+ value_type = float
682
+ if len(splits)==3:
683
+ if _value=="str":
684
+ value_type = str
685
+ elif _value=="int":
686
+ value_type = int
687
+ _value = splits[2]
688
+ if _value=="true" or _value=="false":
689
+ value_type = bool
690
+ if "," in _value:
691
+ _value = _value.split(",")
692
+ for i in range(len(_value)):
693
+ _value[i] = value_type(_value[i])
694
+ value=tuple(_value)
695
+ else:
696
+ value = value_type(_value)
697
+
698
+ if not new_b_name in args_dic:
699
+ args_dic[new_b_name] = {}
700
+ args_dic[new_b_name][key] = value
701
+
702
+ if len(args_dic)==0:
703
+ args_dic = None
704
+ return lr_dic, args_dic
705
+
706
  def create_split_names(split_flag, split_level):
707
  split_names = None
708
  if split_flag:
 
712
  if split_level==1:
713
  unet_names.append(f"lora_unet_down_blocks_")
714
  unet_names.append(f"lora_unet_up_blocks_")
715
+ elif split_level==2 or split_level==0 or split_level==4:
716
+ if split_level>=2:
717
  text_encoder_names = []
718
  for i in range(12):
719
  text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
720
+
721
+ if split_level<=2:
722
+ for i in range(3):
723
+ unet_names.append(f"lora_unet_down_blocks_{i}")
724
+ unet_names.append(f"lora_unet_up_blocks_{i+1}")
725
+
726
+ if split_level>=3:
727
  for i in range(3):
728
+ for j in range(2):
729
+ unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
730
+ for j in range(3):
731
+ unet_names.append(f"lora_unet_up_blocks_{i+1}_attentions_{j}_")
732
  split_names["text_encoder"] = text_encoder_names
733
  split_names["unet"] = unet_names
734
  return split_names
 
740
  import datetime
741
  if os.path.splitext(args.config)[-1] == ".yaml":
742
  args.config = os.path.splitext(args.config)[0]
743
+ config_path = f"{args.config}.yaml"
744
  if os.path.exists(config_path):
745
  print(f"{config_path} から設定を読み込み中...")
746
  margs, rest = parser.parse_known_args()
 
761
  args_type_dic[key] = act.type
762
  #データタイプの確認とargsにkeyの内容を代入していく
763
  for key, v in configs.items():
764
+ if v is not None:
765
+ if key in args_dic:
766
+ if args_dic[key] is not None:
767
+ new_type = type(args_dic[key])
768
+ if (not type(v) == new_type) and (not new_type==list):
769
+ v = new_type(v)
770
+ else:
771
  if not type(v) == args_type_dic[key]:
772
  v = args_type_dic[key](v)
773
+ args_dic[key] = v
774
  #最後にデフォから指定が変わってるものを変更する
775
  for key, v in change_def_dic.items():
776
  args_dic[key] = v
777
  else:
778
  print(f"{config_path} が見つかりませんでした")
779
  return args
780
+
781
+ '''
782
+ class GradientReversalFunction(torch.autograd.Function):
783
+ @staticmethod
784
+ def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
785
+ ctx.save_for_backward(scale)
786
+ return input_forward
787
+ @staticmethod
788
+ def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
789
+ scale, = ctx.saved_tensors
790
+ return scale * -grad_backward, None
791
+
792
+ class GradientReversal(torch.nn.Module):
793
+ def __init__(self, scale: float):
794
+ super(GradientReversal, self).__init__()
795
+ self.scale = torch.tensor(scale)
796
+ def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
797
+ if flag:
798
+ return x
799
+ else:
800
+ return GradientReversalFunction.apply(x, self.scale)
801
+ '''
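
Note: the new --blocks_lr_setting and --block_optim_args options added above are parsed by create_lr_blocks into per-block learning rates and per-block optimizer kwargs. The following is a minimal usage sketch, not part of the commit, assuming append_module.py is importable from the working directory:

from append_module import create_lr_blocks

# --blocks_lr_setting "BASE=5e-5,IN01=1e-4,MID=2e-4"
# --block_optim_args "MID:weight_decay=0.01" "BASE:betas=0.9,0.99"
lr_dic, block_args_dic = create_lr_blocks(
    lr_setting_str="BASE=5e-5,IN01=1e-4,MID=2e-4",
    block_optim_args=["MID:weight_decay=0.01", "BASE:betas=0.9,0.99"],
)

# lr_dic maps resolved LoRA module prefixes to learning rates:
# {"lora_te_": 5e-05, "lora_unet_down_blocks_0_attentions_0_": 0.0001, "lora_unet_mid_block": 0.0002}
print(lr_dic)
# block_args_dic maps prefixes to extra per-group optimizer kwargs:
# {"lora_unet_mid_block": {"weight_decay": 0.01}, "lora_te_": {"betas": (0.9, 0.99)}}
print(block_args_dic)

Both dictionaries are then handed to the patched LoRANetwork.prepare_optimizer_params via its lr_dic and block_args_dic parameters.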
fine_tune.py CHANGED
@@ -13,7 +13,11 @@ import diffusers
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
-
17
 
18
  def collate_fn(examples):
19
  return examples[0]
@@ -30,25 +34,36 @@ def train(args):
30
 
31
  tokenizer = train_util.load_tokenizer(args)
32
 
33
- train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
34
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
35
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
36
- args.bucket_reso_steps, args.bucket_no_upscale,
37
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
38
- args.dataset_repeats, args.debug_dataset)
39
-
40
- # 学習データのdropout率を設定する
41
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
42
-
43
- train_dataset.make_buckets()
44
 
45
  if args.debug_dataset:
46
- train_util.debug_dataset(train_dataset)
47
  return
48
- if len(train_dataset) == 0:
49
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
50
  return
51
 
52
  # acceleratorを準備する
53
  print("prepare accelerator")
54
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -109,7 +124,7 @@ def train(args):
109
  vae.requires_grad_(False)
110
  vae.eval()
111
  with torch.no_grad():
112
- train_dataset.cache_latents(vae)
113
  vae.to("cpu")
114
  if torch.cuda.is_available():
115
  torch.cuda.empty_cache()
@@ -149,33 +164,13 @@ def train(args):
149
 
150
  # 学習に必要なクラスを準備する
151
  print("prepare optimizer, data loader etc.")
152
-
153
- # 8-bit Adamを使う
154
- if args.use_8bit_adam:
155
- try:
156
- import bitsandbytes as bnb
157
- except ImportError:
158
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
159
- print("use 8-bit Adam optimizer")
160
- optimizer_class = bnb.optim.AdamW8bit
161
- elif args.use_lion_optimizer:
162
- try:
163
- import lion_pytorch
164
- except ImportError:
165
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
166
- print("use Lion optimizer")
167
- optimizer_class = lion_pytorch.Lion
168
- else:
169
- optimizer_class = torch.optim.AdamW
170
-
171
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
172
- optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
173
 
174
  # dataloaderを準備する
175
  # DataLoaderのプロセス数:0はメインプロセスになる
176
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
177
  train_dataloader = torch.utils.data.DataLoader(
178
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
179
 
180
  # 学習ステップ数を計算する
181
  if args.max_train_epochs is not None:
@@ -183,8 +178,9 @@ def train(args):
183
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
- lr_scheduler = diffusers.optimization.get_scheduler(
187
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
188
 
189
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
190
  if args.full_fp16:
@@ -218,7 +214,7 @@ def train(args):
218
  # 学習する
219
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
220
  print("running training / 学習開始")
221
- print(f" num examples / サンプル数: {train_dataset.num_train_images}")
222
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
223
  print(f" num epochs / epoch数: {num_train_epochs}")
224
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -237,7 +233,7 @@ def train(args):
237
 
238
  for epoch in range(num_train_epochs):
239
  print(f"epoch {epoch+1}/{num_train_epochs}")
240
- train_dataset.set_current_epoch(epoch + 1)
241
 
242
  for m in training_models:
243
  m.train()
@@ -286,11 +282,11 @@ def train(args):
286
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
287
 
288
  accelerator.backward(loss)
289
- if accelerator.sync_gradients:
290
  params_to_clip = []
291
  for m in training_models:
292
  params_to_clip.extend(m.parameters())
293
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
294
 
295
  optimizer.step()
296
  lr_scheduler.step()
@@ -301,11 +297,16 @@ def train(args):
301
  progress_bar.update(1)
302
  global_step += 1
303
 
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
305
  if args.logging_dir is not None:
306
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
307
  accelerator.log(logs, step=global_step)
308
 
 
309
  loss_total += current_loss
310
  avr_loss = loss_total / (step+1)
311
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -315,7 +316,7 @@ def train(args):
315
  break
316
 
317
  if args.logging_dir is not None:
318
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
319
  accelerator.log(logs, step=epoch+1)
320
 
321
  accelerator.wait_for_everyone()
@@ -325,6 +326,8 @@ def train(args):
325
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
326
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
327
 
  is_main_process = accelerator.is_main_process
329
  if is_main_process:
330
  unet = unwrap_model(unet)
@@ -351,6 +354,8 @@ if __name__ == '__main__':
351
  train_util.add_dataset_arguments(parser, False, True, True)
352
  train_util.add_training_arguments(parser, False)
353
  train_util.add_sd_saving_arguments(parser)
 
 
354
 
355
  parser.add_argument("--diffusers_xformers", action='store_true',
356
  help='use xformers by diffusers / Diffusersでxformersを使用する')
 
13
  from diffusers import DDPMScheduler
14
 
15
  import library.train_util as train_util
16
+ import library.config_util as config_util
17
+ from library.config_util import (
18
+ ConfigSanitizer,
19
+ BlueprintGenerator,
20
+ )
21
 
22
  def collate_fn(examples):
23
  return examples[0]
 
34
 
35
  tokenizer = train_util.load_tokenizer(args)
36
 
37
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
38
+ if args.dataset_config is not None:
39
+ print(f"Load dataset config from {args.dataset_config}")
40
+ user_config = config_util.load_user_config(args.dataset_config)
41
+ ignored = ["train_data_dir", "in_json"]
42
+ if any(getattr(args, attr) is not None for attr in ignored):
43
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
44
+ else:
45
+ user_config = {
46
+ "datasets": [{
47
+ "subsets": [{
48
+ "image_dir": args.train_data_dir,
49
+ "metadata_file": args.in_json,
50
+ }]
51
+ }]
52
+ }
53
+
54
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
  if args.debug_dataset:
58
+ train_util.debug_dataset(train_dataset_group)
59
  return
60
+ if len(train_dataset_group) == 0:
61
  print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
62
  return
63
 
64
+ if cache_latents:
65
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
+
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
124
  vae.requires_grad_(False)
125
  vae.eval()
126
  with torch.no_grad():
127
+ train_dataset_group.cache_latents(vae)
128
  vae.to("cpu")
129
  if torch.cuda.is_available():
130
  torch.cuda.empty_cache()
 
164
 
165
  # 学習に必要なクラスを準備する
166
  print("prepare optimizer, data loader etc.")
167
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
168
 
169
  # dataloaderを準備する
170
  # DataLoaderのプロセス数:0はメインプロセスになる
171
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
172
  train_dataloader = torch.utils.data.DataLoader(
173
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
174
 
175
  # 学習ステップ数を計算する
176
  if args.max_train_epochs is not None:
 
178
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
179
 
180
  # lr schedulerを用意する
181
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
182
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
183
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
184
 
185
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
186
  if args.full_fp16:
 
214
  # 学習する
215
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
216
  print("running training / 学習開始")
217
+ print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
218
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
219
  print(f" num epochs / epoch数: {num_train_epochs}")
220
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
233
 
234
  for epoch in range(num_train_epochs):
235
  print(f"epoch {epoch+1}/{num_train_epochs}")
236
+ train_dataset_group.set_current_epoch(epoch + 1)
237
 
238
  for m in training_models:
239
  m.train()
 
282
  loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
283
 
284
  accelerator.backward(loss)
285
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
286
  params_to_clip = []
287
  for m in training_models:
288
  params_to_clip.extend(m.parameters())
289
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
 
297
  progress_bar.update(1)
298
  global_step += 1
299
 
300
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
301
+
302
  current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
303
  if args.logging_dir is not None:
304
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
305
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
306
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
307
  accelerator.log(logs, step=global_step)
308
 
309
+ # TODO moving averageにする
310
  loss_total += current_loss
311
  avr_loss = loss_total / (step+1)
312
  logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 
316
  break
317
 
318
  if args.logging_dir is not None:
319
+ logs = {"loss/epoch": loss_total / len(train_dataloader)}
320
  accelerator.log(logs, step=epoch+1)
321
 
322
  accelerator.wait_for_everyone()
 
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
329
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
330
+
331
  is_main_process = accelerator.is_main_process
332
  if is_main_process:
333
  unet = unwrap_model(unet)
 
354
  train_util.add_dataset_arguments(parser, False, True, True)
355
  train_util.add_training_arguments(parser, False)
356
  train_util.add_sd_saving_arguments(parser)
357
+ train_util.add_optimizer_arguments(parser)
358
+ config_util.add_config_arguments(parser)
359
 
360
  parser.add_argument("--diffusers_xformers", action='store_true',
361
  help='use xformers by diffusers / Diffusersでxformersを使用する')
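
Note: fine_tune.py now builds its dataset through library.config_util instead of constructing FineTuningDataset directly. A minimal sketch of that wiring, assuming config_util exposes the API used in the diff (ConfigSanitizer, BlueprintGenerator, load_user_config, generate_dataset_group_by_blueprint):

import library.config_util as config_util
from library.config_util import ConfigSanitizer, BlueprintGenerator

def build_dataset_group(args, tokenizer):
    # sanitizer flags: (support_dreambooth, support_finetuning, support_dropout)
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
    if args.dataset_config is not None:
        # an explicit config file overrides --train_data_dir / --in_json
        user_config = config_util.load_user_config(args.dataset_config)
    else:
        # fall back to the legacy CLI options, wrapped in the config schema
        user_config = {
            "datasets": [{
                "subsets": [{
                    "image_dir": args.train_data_dir,
                    "metadata_file": args.in_json,
                }]
            }]
        }
    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    return config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)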
finetune/blip/blip.py ADDED
@@ -0,0 +1,240 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # from models.vit import VisionTransformer, interpolate_pos_embed
12
+ # from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from blip.vit import VisionTransformer, interpolate_pos_embed
14
+ from blip.med import BertConfig, BertModel, BertLMHeadModel
15
+ from transformers import BertTokenizer
16
+
17
+ import torch
18
+ from torch import nn
19
+ import torch.nn.functional as F
20
+
21
+ import os
22
+ from urllib.parse import urlparse
23
+ from timm.models.hub import download_cached_file
24
+
25
+ class BLIP_Base(nn.Module):
26
+ def __init__(self,
27
+ med_config = 'configs/med_config.json',
28
+ image_size = 224,
29
+ vit = 'base',
30
+ vit_grad_ckpt = False,
31
+ vit_ckpt_layer = 0,
32
+ ):
33
+ """
34
+ Args:
35
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
36
+ image_size (int): input image size
37
+ vit (str): model size of vision transformer
38
+ """
39
+ super().__init__()
40
+
41
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
42
+ self.tokenizer = init_tokenizer()
43
+ med_config = BertConfig.from_json_file(med_config)
44
+ med_config.encoder_width = vision_width
45
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
46
+
47
+
48
+ def forward(self, image, caption, mode):
49
+
50
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
51
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
52
+
53
+ if mode=='image':
54
+ # return image features
55
+ image_embeds = self.visual_encoder(image)
56
+ return image_embeds
57
+
58
+ elif mode=='text':
59
+ # return text features
60
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
61
+ return_dict = True, mode = 'text')
62
+ return text_output.last_hidden_state
63
+
64
+ elif mode=='multimodal':
65
+ # return multimodel features
66
+ image_embeds = self.visual_encoder(image)
67
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
68
+
69
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
70
+ output = self.text_encoder(text.input_ids,
71
+ attention_mask = text.attention_mask,
72
+ encoder_hidden_states = image_embeds,
73
+ encoder_attention_mask = image_atts,
74
+ return_dict = True,
75
+ )
76
+ return output.last_hidden_state
77
+
78
+
79
+
80
+ class BLIP_Decoder(nn.Module):
81
+ def __init__(self,
82
+ med_config = 'configs/med_config.json',
83
+ image_size = 384,
84
+ vit = 'base',
85
+ vit_grad_ckpt = False,
86
+ vit_ckpt_layer = 0,
87
+ prompt = 'a picture of ',
88
+ ):
89
+ """
90
+ Args:
91
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
92
+ image_size (int): input image size
93
+ vit (str): model size of vision transformer
94
+ """
95
+ super().__init__()
96
+
97
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
98
+ self.tokenizer = init_tokenizer()
99
+ med_config = BertConfig.from_json_file(med_config)
100
+ med_config.encoder_width = vision_width
101
+ self.text_decoder = BertLMHeadModel(config=med_config)
102
+
103
+ self.prompt = prompt
104
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
105
+
106
+
107
+ def forward(self, image, caption):
108
+
109
+ image_embeds = self.visual_encoder(image)
110
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
111
+
112
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
113
+
114
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
115
+
116
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
117
+ decoder_targets[:,:self.prompt_length] = -100
118
+
119
+ decoder_output = self.text_decoder(text.input_ids,
120
+ attention_mask = text.attention_mask,
121
+ encoder_hidden_states = image_embeds,
122
+ encoder_attention_mask = image_atts,
123
+ labels = decoder_targets,
124
+ return_dict = True,
125
+ )
126
+ loss_lm = decoder_output.loss
127
+
128
+ return loss_lm
129
+
130
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
131
+ image_embeds = self.visual_encoder(image)
132
+
133
+ if not sample:
134
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
135
+
136
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
137
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
138
+
139
+ prompt = [self.prompt] * image.size(0)
140
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
141
+ input_ids[:,0] = self.tokenizer.bos_token_id
142
+ input_ids = input_ids[:, :-1]
143
+
144
+ if sample:
145
+ #nucleus sampling
146
+ outputs = self.text_decoder.generate(input_ids=input_ids,
147
+ max_length=max_length,
148
+ min_length=min_length,
149
+ do_sample=True,
150
+ top_p=top_p,
151
+ num_return_sequences=1,
152
+ eos_token_id=self.tokenizer.sep_token_id,
153
+ pad_token_id=self.tokenizer.pad_token_id,
154
+ repetition_penalty=1.1,
155
+ **model_kwargs)
156
+ else:
157
+ #beam search
158
+ outputs = self.text_decoder.generate(input_ids=input_ids,
159
+ max_length=max_length,
160
+ min_length=min_length,
161
+ num_beams=num_beams,
162
+ eos_token_id=self.tokenizer.sep_token_id,
163
+ pad_token_id=self.tokenizer.pad_token_id,
164
+ repetition_penalty=repetition_penalty,
165
+ **model_kwargs)
166
+
167
+ captions = []
168
+ for output in outputs:
169
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
170
+ captions.append(caption[len(self.prompt):])
171
+ return captions
172
+
173
+
174
+ def blip_decoder(pretrained='',**kwargs):
175
+ model = BLIP_Decoder(**kwargs)
176
+ if pretrained:
177
+ model,msg = load_checkpoint(model,pretrained)
178
+ assert(len(msg.missing_keys)==0)
179
+ return model
180
+
181
+ def blip_feature_extractor(pretrained='',**kwargs):
182
+ model = BLIP_Base(**kwargs)
183
+ if pretrained:
184
+ model,msg = load_checkpoint(model,pretrained)
185
+ assert(len(msg.missing_keys)==0)
186
+ return model
187
+
188
+ def init_tokenizer():
189
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
190
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
191
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
192
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
193
+ return tokenizer
194
+
195
+
196
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
197
+
198
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
199
+ if vit=='base':
200
+ vision_width = 768
201
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
202
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
203
+ drop_path_rate=0 or drop_path_rate
204
+ )
205
+ elif vit=='large':
206
+ vision_width = 1024
207
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
208
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
209
+ drop_path_rate=0.1 or drop_path_rate
210
+ )
211
+ return visual_encoder, vision_width
212
+
213
+ def is_url(url_or_filename):
214
+ parsed = urlparse(url_or_filename)
215
+ return parsed.scheme in ("http", "https")
216
+
217
+ def load_checkpoint(model,url_or_filename):
218
+ if is_url(url_or_filename):
219
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
220
+ checkpoint = torch.load(cached_file, map_location='cpu')
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
223
+ else:
224
+ raise RuntimeError('checkpoint url or path is invalid')
225
+
226
+ state_dict = checkpoint['model']
227
+
228
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
229
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
230
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
231
+ model.visual_encoder_m)
232
+ for key in model.state_dict().keys():
233
+ if key in state_dict.keys():
234
+ if state_dict[key].shape!=model.state_dict()[key].shape:
235
+ del state_dict[key]
236
+
237
+ msg = model.load_state_dict(state_dict,strict=False)
238
+ print('load checkpoint from %s'%url_or_filename)
239
+ return model,msg
240
+
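
Note: finetune/blip/blip.py vendors the Salesforce BLIP captioner. A hypothetical smoke test (the checkpoint and med_config paths below are placeholders, not part of this commit), assuming the finetune/ directory is on sys.path:

import torch
from blip.blip import blip_decoder

med_config = "configs/med_config.json"               # BERT config shipped with BLIP
checkpoint = "model_base_caption_capfilt_large.pth"  # any BLIP captioning checkpoint

model = blip_decoder(pretrained=checkpoint, image_size=384, vit="base", med_config=med_config)
model.eval()

# generate() expects a batch of preprocessed images of shape (B, 3, image_size, image_size)
image = torch.zeros(1, 3, 384, 384)
with torch.no_grad():
    captions = model.generate(image, sample=False, num_beams=3, max_length=30, min_length=10)
print(captions)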
finetune/blip/med.py ADDED
@@ -0,0 +1,955 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
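transpose_for_scores and the matmuls above are the usual multi-head reshape plus scaled dot-product attention; a minimal standalone sketch with concrete, made-up sizes (independent of the classes in this file):

import math
import torch

batch, seq_len, num_heads, head_size = 2, 5, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_size)

def transpose_for_scores(x):
    # (B, L, H*D) -> (B, H, L, D)
    return x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

q = k = v = transpose_for_scores(hidden)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # (B, H, L, L)
probs = scores.softmax(dim=-1)
context = torch.matmul(probs, v).permute(0, 2, 1, 3).reshape(batch, seq_len, -1)  # back to (B, L, H*D)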
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
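The seq_ids comparison in the decoder branch above produces a lower-triangular (causal) mask; a tiny check with made-up sizes:

import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask.int()[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])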
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
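Together, BertModel and BertLMHeadModel form BLIP's text side: image features enter as encoder_hidden_states and the LM head is trained as a left-to-right decoder. A rough, hedged sketch of how they can be wired up; the tensor shapes, config path and prompt are illustrative only, not taken from this upload:

import torch
from transformers import BertConfig, BertTokenizer

config = BertConfig.from_json_file("finetune/blip/med_config.json")  # next file in this upload
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

decoder = BertLMHeadModel(config)  # the class defined above, not transformers' own

image_embeds = torch.randn(1, 577, config.encoder_width)  # e.g. ViT-B/16 at 384px: 24*24+1 tokens (hypothetical)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

text = tokenizer("a photo of", return_tensors="pt")
out = decoder(
    input_ids=text.input_ids,
    attention_mask=text.attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    labels=text.input_ids,
    is_decoder=True,
    mode="multimodal",
)
print(out.loss)  # captioning-style LM loss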
finetune/blip/med_config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
22
+
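This config is plain BertConfig JSON plus two fields the code above relies on: encoder_width (width of the vision features fed to the cross-attention key/value projections) and add_cross_attention (enables the cross-attention layers in BertLayer); the vocab_size of 30524 (rather than BERT's 30522) leaves room for the extra special tokens BLIP adds to the tokenizer. Loading it with transformers keeps the non-standard key as an extra attribute; a small hedged check:

from transformers import BertConfig

config = BertConfig.from_json_file("finetune/blip/med_config.json")
print(config.encoder_width, config.add_cross_attention)  # 768 True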
finetune/blip/vit.py ADDED
@@ -0,0 +1,305 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed  # resize_pos_embed is used in _load_weights below
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
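interpolate_pos_embed is what allows a ViT pretrained at one resolution to be loaded at another: only the patch position tokens are resized (bicubically), while the class/extra tokens are kept as-is. A hedged usage sketch; the checkpoint path and its layout are assumptions:

import torch

model = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)

checkpoint = torch.load("vit_base_patch16_224.pth", map_location="cpu")  # hypothetical file
state_dict = checkpoint.get("model", checkpoint)  # some checkpoints nest weights under "model"
state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], model)
missing, unexpected = model.load_state_dict(state_dict, strict=False)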
finetune/clean_captions_and_tags.py ADDED
@@ -0,0 +1,184 @@
1
+ # This script is licensed under the Apache License 2.0
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import json
8
+ import re
9
+
10
+ from tqdm import tqdm
11
+
12
+ PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
13
+ PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
14
+ PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
15
+ PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
16
+
17
+ # when multiple people are present, remove tags if multiple hair colors or eye colors are defined
18
+ PATTERNS_REMOVE_IN_MULTI = [
19
+ PATTERN_HAIR_LENGTH,
20
+ PATTERN_HAIR_CUT,
21
+ re.compile(r', [\w\-]+ eyes, '),
22
+ re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
23
+ # remove if multiple hairstyle tags are defined
24
+ re.compile(
25
+ r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
26
+ ]
27
+
28
+
29
+ def clean_tags(image_key, tags):
30
+ # replace '_' with ' '
31
+ tags = tags.replace('^_^', '^@@@^')
32
+ tags = tags.replace('_', ' ')
33
+ tags = tags.replace('^@@@^', '^_^')
34
+
35
+ # remove rating: deepdanbooru only
36
+ tokens = tags.split(", rating")
37
+ if len(tokens) == 1:
38
+ # the WD14 tagger takes this branch, so no message is printed
39
+ # print("no rating:")
40
+ # print(f"{image_key} {tags}")
41
+ pass
42
+ else:
43
+ if len(tokens) > 2:
44
+ print("multiple ratings:")
45
+ print(f"{image_key} {tags}")
46
+ tags = tokens[0]
47
+
48
+ tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
49
+
50
+ # when multiple people are present, remove hair color and similar tags
51
+ if 'girls' in tags or 'boys' in tags:
52
+ for pat in PATTERNS_REMOVE_IN_MULTI:
53
+ found = pat.findall(tags)
54
+ if len(found) > 1: # 二つ以上、タグがある
55
+ tags = pat.sub("", tags)
56
+
57
+ # special handling for hair tags
58
+ srch_hair_len = PATTERN_HAIR_LENGTH.search(tags)  # hair length tags are an exception (everyone may share the same hair length), so set them aside
59
+ if srch_hair_len:
60
+ org = srch_hair_len.group()
61
+ tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
62
+
63
+ found = PATTERN_HAIR.findall(tags)
64
+ if len(found) > 1:
65
+ tags = PATTERN_HAIR.sub("", tags)
66
+
67
+ if srch_hair_len:
68
+ tags = tags.replace(", @@@, ", org) # 戻す
69
+
70
+ # remove redundant tag pairs such as "white shirt" and "shirt"
71
+ found = PATTERN_WORD.findall(tags)
72
+ for word in found:
73
+ if re.search(f", ((\w+) )+{word}, ", tags):
74
+ tags = tags.replace(f", {word}, ", "")
75
+
76
+ tags = tags.replace(", , ", ", ")
77
+ assert tags.startswith(", ") and tags.endswith(", ")
78
+ tags = tags[2:-2]
79
+ return tags
80
+
81
+
82
+ # searched and replaced in order from top to bottom
83
+ # ('source string', 'replacement string')
84
+ CAPTION_REPLACEMENTS = [
85
+ ('anime anime', 'anime'),
86
+ ('young ', ''),
87
+ ('anime girl', 'girl'),
88
+ ('cartoon female', 'girl'),
89
+ ('cartoon lady', 'girl'),
90
+ ('cartoon character', 'girl'), # a or ~s
91
+ ('cartoon woman', 'girl'),
92
+ ('cartoon women', 'girls'),
93
+ ('cartoon girl', 'girl'),
94
+ ('anime female', 'girl'),
95
+ ('anime lady', 'girl'),
96
+ ('anime character', 'girl'), # a or ~s
97
+ ('anime woman', 'girl'),
98
+ ('anime women', 'girls'),
99
+ ('lady', 'girl'),
100
+ ('female', 'girl'),
101
+ ('woman', 'girl'),
102
+ ('women', 'girls'),
103
+ ('people', 'girls'),
104
+ ('person', 'girl'),
105
+ ('a cartoon figure', 'a figure'),
106
+ ('a cartoon image', 'an image'),
107
+ ('a cartoon picture', 'a picture'),
108
+ ('an anime cartoon image', 'an image'),
109
+ ('a cartoon anime drawing', 'a drawing'),
110
+ ('a cartoon drawing', 'a drawing'),
111
+ ('girl girl', 'girl'),
112
+ ]
113
+
114
+
115
+ def clean_caption(caption):
116
+ for rf, rt in CAPTION_REPLACEMENTS:
117
+ replaced = True
118
+ while replaced:
119
+ bef = caption
120
+ caption = caption.replace(rf, rt)
121
+ replaced = bef != caption
122
+ return caption
123
+
124
+
125
+ def main(args):
126
+ if os.path.exists(args.in_json):
127
+ print(f"loading existing metadata: {args.in_json}")
128
+ with open(args.in_json, "rt", encoding='utf-8') as f:
129
+ metadata = json.load(f)
130
+ else:
131
+ print("no metadata / メタデータファイルがありません")
132
+ return
133
+
134
+ print("cleaning captions and tags.")
135
+ image_keys = list(metadata.keys())
136
+ for image_key in tqdm(image_keys):
137
+ tags = metadata[image_key].get('tags')
138
+ if tags is None:
139
+ print(f"image does not have tags / メタデータにタグがありません: {image_key}")
140
+ else:
141
+ org = tags
142
+ tags = clean_tags(image_key, tags)
143
+ metadata[image_key]['tags'] = tags
144
+ if args.debug and org != tags:
145
+ print("FROM: " + org)
146
+ print("TO: " + tags)
147
+
148
+ caption = metadata[image_key].get('caption')
149
+ if caption is None:
150
+ print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
151
+ else:
152
+ org = caption
153
+ caption = clean_caption(caption)
154
+ metadata[image_key]['caption'] = caption
155
+ if args.debug and org != caption:
156
+ print("FROM: " + org)
157
+ print("TO: " + caption)
158
+
159
+ # write out the metadata and finish
160
+ print(f"writing metadata: {args.out_json}")
161
+ with open(args.out_json, "wt", encoding='utf-8') as f:
162
+ json.dump(metadata, f, indent=2)
163
+ print("done!")
164
+
165
+
166
+ if __name__ == '__main__':
167
+ parser = argparse.ArgumentParser()
168
+ # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
169
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
170
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
171
+ parser.add_argument("--debug", action="store_true", help="debug mode")
172
+
173
+ args, unknown = parser.parse_known_args()
174
+ if len(unknown) == 1:
175
+ print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
176
+ print("All captions and tags in the metadata are processed.")
177
+ print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
178
+ print("メタデータ内のすべてのキャプションとタグが処理されます。")
179
+ args.in_json = args.out_json
180
+ args.out_json = unknown[0]
181
+ elif len(unknown) > 0:
182
+ raise ValueError(f"error: unrecognized arguments: {unknown}")
183
+
184
+ main(args)
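A typical invocation reads one metadata JSON and writes the cleaned copy, with --debug printing every changed caption and tag line (file names below are placeholders):

python finetune/clean_captions_and_tags.py meta_cap_dd.json meta_cap_dd_clean.json --debug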
finetune/hypernetwork_nai.py ADDED
@@ -0,0 +1,96 @@
1
+ # NAI compatible
2
+
3
+ import torch
4
+
5
+
6
+ class HypernetworkModule(torch.nn.Module):
7
+ def __init__(self, dim, multiplier=1.0):
8
+ super().__init__()
9
+
10
+ linear1 = torch.nn.Linear(dim, dim * 2)
11
+ linear2 = torch.nn.Linear(dim * 2, dim)
12
+ linear1.weight.data.normal_(mean=0.0, std=0.01)
13
+ linear1.bias.data.zero_()
14
+ linear2.weight.data.normal_(mean=0.0, std=0.01)
15
+ linear2.bias.data.zero_()
16
+ linears = [linear1, linear2]
17
+
18
+ self.linear = torch.nn.Sequential(*linears)
19
+ self.multiplier = multiplier
20
+
21
+ def forward(self, x):
22
+ return x + self.linear(x) * self.multiplier
23
+
24
+
25
+ class Hypernetwork(torch.nn.Module):
26
+ enable_sizes = [320, 640, 768, 1280]
27
+ # return self.modules[Hypernetwork.enable_sizes.index(size)]
28
+
29
+ def __init__(self, multiplier=1.0) -> None:
30
+ super().__init__()
31
+ self.modules = []
32
+ for size in Hypernetwork.enable_sizes:
33
+ self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34
+ self.register_module(f"{size}_0", self.modules[-1][0])
35
+ self.register_module(f"{size}_1", self.modules[-1][1])
36
+
37
+ def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38
+ blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39
+ for block in blocks:
40
+ for subblk in block:
41
+ if 'SpatialTransformer' in str(type(subblk)):
42
+ for tf_block in subblk.transformer_blocks:
43
+ for attn in [tf_block.attn1, tf_block.attn2]:
44
+ size = attn.context_dim
45
+ if size in Hypernetwork.enable_sizes:
46
+ attn.hypernetwork = self
47
+ else:
48
+ attn.hypernetwork = None
49
+
50
+ def apply_to_diffusers(self, text_encoder, vae, unet):
51
+ blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52
+ for block in blocks:
53
+ if hasattr(block, 'attentions'):
54
+ for subblk in block.attentions:
55
+ if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56
+ for tf_block in subblk.transformer_blocks:
57
+ for attn in [tf_block.attn1, tf_block.attn2]:
58
+ size = attn.to_k.in_features
59
+ if size in Hypernetwork.enable_sizes:
60
+ attn.hypernetwork = self
61
+ else:
62
+ attn.hypernetwork = None
63
+ return True # TODO error checking
64
+
65
+ def forward(self, x, context):
66
+ size = context.shape[-1]
67
+ assert size in Hypernetwork.enable_sizes
68
+ module = self.modules[Hypernetwork.enable_sizes.index(size)]
69
+ return module[0].forward(context), module[1].forward(context)
70
+
71
+ def load_from_state_dict(self, state_dict):
72
+ # old ver to new ver
73
+ changes = {
74
+ 'linear1.bias': 'linear.0.bias',
75
+ 'linear1.weight': 'linear.0.weight',
76
+ 'linear2.bias': 'linear.1.bias',
77
+ 'linear2.weight': 'linear.1.weight',
78
+ }
79
+ for key_from, key_to in changes.items():
80
+ if key_from in state_dict:
81
+ state_dict[key_to] = state_dict[key_from]
82
+ del state_dict[key_from]
83
+
84
+ for size, sd in state_dict.items():
85
+ if type(size) == int:
86
+ self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87
+ self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88
+ return True
89
+
90
+ def get_state_dict(self):
91
+ state_dict = {}
92
+ for i, size in enumerate(Hypernetwork.enable_sizes):
93
+ sd0 = self.modules[i][0].state_dict()
94
+ sd1 = self.modules[i][1].state_dict()
95
+ state_dict[size] = [sd0, sd1]
96
+ return state_dict
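
A quick smoke test of the class above, assuming the file can be imported from the repo root: forward() accepts any attention context whose last dimension is listed in enable_sizes and returns a pair of residually transformed contexts for the k and v projections.

import torch
from finetune.hypernetwork_nai import Hypernetwork  # import path is an assumption

hypernet = Hypernetwork(multiplier=1.0)
context = torch.randn(2, 77, 768)       # 768 is one of Hypernetwork.enable_sizes
k_ctx, v_ctx = hypernet(None, context)  # the x argument is unused by forward()
assert k_ctx.shape == context.shape and v_ctx.shape == context.shape
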
finetune/make_captions.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import json
5
+ import random
6
+
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from blip.blip import blip_decoder
14
+ import library.train_util as train_util
15
+
16
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+
19
+ IMAGE_SIZE = 384
20
+
21
+ # 正方形でいいのか? という気がするがソースがそうなので
22
+ IMAGE_TRANSFORM = transforms.Compose([
23
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
26
+ ])
27
+
28
+ # 共通化したいが微妙に処理が異なる……
29
+ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
30
+ def __init__(self, image_paths):
31
+ self.images = image_paths
32
+
33
+ def __len__(self):
34
+ return len(self.images)
35
+
36
+ def __getitem__(self, idx):
37
+ img_path = self.images[idx]
38
+
39
+ try:
40
+ image = Image.open(img_path).convert("RGB")
41
+ # convert to tensor temporarily so dataloader will accept it
42
+ tensor = IMAGE_TRANSFORM(image)
43
+ except Exception as e:
44
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
45
+ return None
46
+
47
+ return (tensor, img_path)
48
+
49
+
50
+ def collate_fn_remove_corrupted(batch):
51
+ """Collate function that allows to remove corrupted examples in the
52
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
53
+ The 'None's in the batch are removed.
54
+ """
55
+ # Filter out all the Nones (corrupted examples)
56
+ batch = list(filter(lambda x: x is not None, batch))
57
+ return batch
58
+
59
+
60
+ def main(args):
61
+ # fix the seed for reproducibility
62
+ seed = args.seed # + utils.get_rank()
63
+ torch.manual_seed(seed)
64
+ np.random.seed(seed)
65
+ random.seed(seed)
66
+
67
+ if not os.path.exists("blip"):
68
+ args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
69
+
70
+ cwd = os.getcwd()
71
+ print('Current Working Directory is: ', cwd)
72
+ os.chdir('finetune')
73
+
74
+ print(f"load images from {args.train_data_dir}")
75
+ image_paths = train_util.glob_images(args.train_data_dir)
76
+ print(f"found {len(image_paths)} images.")
77
+
78
+ print(f"loading BLIP caption: {args.caption_weights}")
79
+ model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
80
+ model.eval()
81
+ model = model.to(DEVICE)
82
+ print("BLIP loaded")
83
+
84
+ # captioningする
85
+ def run_batch(path_imgs):
86
+ imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
87
+
88
+ with torch.no_grad():
89
+ if args.beam_search:
90
+ captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
91
+ max_length=args.max_length, min_length=args.min_length)
92
+ else:
93
+ captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
94
+
95
+ for (image_path, _), caption in zip(path_imgs, captions):
96
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
97
+ f.write(caption + "\n")
98
+ if args.debug:
99
+ print(image_path, caption)
100
+
101
+ # 読み込みの高速化のためにDataLoaderを使うオプション
102
+ if args.max_data_loader_n_workers is not None:
103
+ dataset = ImageLoadingTransformDataset(image_paths)
104
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
105
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
106
+ else:
107
+ data = [[(None, ip)] for ip in image_paths]
108
+
109
+ b_imgs = []
110
+ for data_entry in tqdm(data, smoothing=0.0):
111
+ for data in data_entry:
112
+ if data is None:
113
+ continue
114
+
115
+ img_tensor, image_path = data
116
+ if img_tensor is None:
117
+ try:
118
+ raw_image = Image.open(image_path)
119
+ if raw_image.mode != 'RGB':
120
+ raw_image = raw_image.convert("RGB")
121
+ img_tensor = IMAGE_TRANSFORM(raw_image)
122
+ except Exception as e:
123
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
124
+ continue
125
+
126
+ b_imgs.append((image_path, img_tensor))
127
+ if len(b_imgs) >= args.batch_size:
128
+ run_batch(b_imgs)
129
+ b_imgs.clear()
130
+ if len(b_imgs) > 0:
131
+ run_batch(b_imgs)
132
+
133
+ print("done!")
134
+
135
+
136
+ if __name__ == '__main__':
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
139
+ parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
140
+ help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
141
+ parser.add_argument("--caption_extention", type=str, default=None,
142
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
143
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
144
+ parser.add_argument("--beam_search", action="store_true",
145
+ help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
146
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
147
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
148
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
149
+ parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
150
+ parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
151
+ parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
152
+ parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
153
+ parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
154
+ parser.add_argument("--debug", action="store_true", help="debug mode")
155
+
156
+ args = parser.parse_args()
157
+
158
+ # スペルミスしていたオプションを復元する
159
+ if args.caption_extention is not None:
160
+ args.caption_extension = args.caption_extention
161
+
162
+ main(args)
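
After the script has run, every image has a sibling caption file; a small sketch for inspecting the results, assuming a hypothetical train_images directory and the default --caption_extension of ".caption".

from pathlib import Path

for cap_file in Path("train_images").glob("*.caption"):
    print(cap_file.stem, "->", cap_file.read_text(encoding="utf-8").strip())
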
finetune/make_captions_by_git.py ADDED
@@ -0,0 +1,145 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+ from transformers.generation.utils import GenerationMixin
10
+
11
+ import library.train_util as train_util
12
+
13
+
14
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ PATTERN_REPLACE = [
17
+ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
18
+ re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
19
+ re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
20
+ re.compile(r'with the number \d+ on (it|\w+ \w+)'),
21
+ re.compile(r'with the words "'),
22
+ re.compile(r'word \w+ on it'),
23
+ re.compile(r'that says the word \w+ on it'),
24
+ re.compile('that says\'the word "( on it)?'),
25
+ ]
26
+
27
+ # 誤検知しまくりの with the word xxxx を消す
28
+
29
+
30
+ def remove_words(captions, debug):
31
+ removed_caps = []
32
+ for caption in captions:
33
+ cap = caption
34
+ for pat in PATTERN_REPLACE:
35
+ cap = pat.sub("", cap)
36
+ if debug and cap != caption:
37
+ print(caption)
38
+ print(cap)
39
+ removed_caps.append(cap)
40
+ return removed_caps
41
+
42
+
43
+ def collate_fn_remove_corrupted(batch):
44
+ """Collate function that allows to remove corrupted examples in the
45
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
46
+ The 'None's in the batch are removed.
47
+ """
48
+ # Filter out all the Nones (corrupted examples)
49
+ batch = list(filter(lambda x: x is not None, batch))
50
+ return batch
51
+
52
+
53
+ def main(args):
54
+ # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
55
+ org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
56
+ curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
57
+
58
+ # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
59
+ # ここより上で置き換えようとするとすごく大変
60
+ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
61
+ input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
62
+ if input_ids.size()[0] != curr_batch_size[0]:
63
+ input_ids = input_ids.repeat(curr_batch_size[0], 1)
64
+ return input_ids
65
+ GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
66
+
67
+ print(f"load images from {args.train_data_dir}")
68
+ image_paths = train_util.glob_images(args.train_data_dir)
69
+ print(f"found {len(image_paths)} images.")
70
+
71
+ # できればcacheに依存せず明示的にダウンロードしたい
72
+ print(f"loading GIT: {args.model_id}")
73
+ git_processor = AutoProcessor.from_pretrained(args.model_id)
74
+ git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
75
+ print("GIT loaded")
76
+
77
+ # captioningする
78
+ def run_batch(path_imgs):
79
+ imgs = [im for _, im in path_imgs]
80
+
81
+ curr_batch_size[0] = len(path_imgs)
82
+ inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
83
+ generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
84
+ captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
85
+
86
+ if args.remove_words:
87
+ captions = remove_words(captions, args.debug)
88
+
89
+ for (image_path, _), caption in zip(path_imgs, captions):
90
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
91
+ f.write(caption + "\n")
92
+ if args.debug:
93
+ print(image_path, caption)
94
+
95
+ # 読み込みの高速化のためにDataLoaderを使うオプション
96
+ if args.max_data_loader_n_workers is not None:
97
+ dataset = train_util.ImageLoadingDataset(image_paths)
98
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
99
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
100
+ else:
101
+ data = [[(None, ip)] for ip in image_paths]
102
+
103
+ b_imgs = []
104
+ for data_entry in tqdm(data, smoothing=0.0):
105
+ for data in data_entry:
106
+ if data is None:
107
+ continue
108
+
109
+ image, image_path = data
110
+ if image is None:
111
+ try:
112
+ image = Image.open(image_path)
113
+ if image.mode != 'RGB':
114
+ image = image.convert("RGB")
115
+ except Exception as e:
116
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
117
+ continue
118
+
119
+ b_imgs.append((image_path, image))
120
+ if len(b_imgs) >= args.batch_size:
121
+ run_batch(b_imgs)
122
+ b_imgs.clear()
123
+
124
+ if len(b_imgs) > 0:
125
+ run_batch(b_imgs)
126
+
127
+ print("done!")
128
+
129
+
130
+ if __name__ == '__main__':
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
133
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
134
+ parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
135
+ help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
136
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
137
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
138
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
139
+ parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
140
+ parser.add_argument("--remove_words", action="store_true",
141
+ help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
142
+ parser.add_argument("--debug", action="store_true", help="debug mode")
143
+
144
+ args = parser.parse_args()
145
+ main(args)
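
To illustrate the --remove_words cleanup above, a sketch with a made-up caption, assuming the file can be imported; the exact output depends on which PATTERN_REPLACE entries match.

from finetune.make_captions_by_git import remove_words  # import path is an assumption

caps = ['a cartoon character with a sign that says "hello" on it']
print(remove_words(caps, debug=False))
# expected: something like ['a cartoon character ']
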
finetune/merge_captions_to_metadata.py ADDED
@@ -0,0 +1,67 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge caption texts to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ caption_path = image_path.with_suffix(args.caption_extension)
30
+ caption = caption_path.read_text(encoding='utf-8').strip()
31
+
32
+ image_key = str(image_path) if args.full_path else image_path.stem
33
+ if image_key not in metadata:
34
+ metadata[image_key] = {}
35
+
36
+ metadata[image_key]['caption'] = caption
37
+ if args.debug:
38
+ print(image_key, caption)
39
+
40
+ # metadataを書き出して終わり
41
+ print(f"writing metadata: {args.out_json}")
42
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
43
+ print("done!")
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
49
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
50
+ parser.add_argument("--in_json", type=str,
51
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
52
+ parser.add_argument("--caption_extention", type=str, default=None,
53
+ help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
54
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
55
+ parser.add_argument("--full_path", action="store_true",
56
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
57
+ parser.add_argument("--recursive", action="store_true",
58
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
59
+ parser.add_argument("--debug", action="store_true", help="debug mode")
60
+
61
+ args = parser.parse_args()
62
+
63
+ # スペルミスしていたオプションを復元する
64
+ if args.caption_extention is not None:
65
+ args.caption_extension = args.caption_extention
66
+
67
+ main(args)
finetune/merge_dd_tags_to_metadata.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import library.train_util as train_util
7
+
8
+
9
+ def main(args):
10
+ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
11
+
12
+ train_data_dir_path = Path(args.train_data_dir)
13
+ image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14
+ print(f"found {len(image_paths)} images.")
15
+
16
+ if args.in_json is None and Path(args.out_json).is_file():
17
+ args.in_json = args.out_json
18
+
19
+ if args.in_json is not None:
20
+ print(f"loading existing metadata: {args.in_json}")
21
+ metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22
+ print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
23
+ else:
24
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
25
+ metadata = {}
26
+
27
+ print("merge tags to metadata json.")
28
+ for image_path in tqdm(image_paths):
29
+ tags_path = image_path.with_suffix(args.caption_extension)
30
+ tags = tags_path.read_text(encoding='utf-8').strip()
31
+
32
+ image_key = str(image_path) if args.full_path else image_path.stem
33
+ if image_key not in metadata:
34
+ metadata[image_key] = {}
35
+
36
+ metadata[image_key]['tags'] = tags
37
+ if args.debug:
38
+ print(image_key, tags)
39
+
40
+ # metadataを書き出して終わり
41
+ print(f"writing metadata: {args.out_json}")
42
+ Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
43
+
44
+ print("done!")
45
+
46
+
47
+ if __name__ == '__main__':
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
50
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
51
+ parser.add_argument("--in_json", type=str,
52
+ help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
53
+ parser.add_argument("--full_path", action="store_true",
54
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
55
+ parser.add_argument("--recursive", action="store_true",
56
+ help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
57
+ parser.add_argument("--caption_extension", type=str, default=".txt",
58
+ help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
59
+ parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
60
+
61
+ args = parser.parse_args()
62
+ main(args)
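
When both merge scripts above are run against the same out_json, each image key ends up carrying caption and tags side by side; a hypothetical resulting entry, written as a Python literal:

merged_entry = {
    "img_0001": {                      # or the full path when --full_path is given
        "caption": "a girl standing in a field",
        "tags": "1girl, outdoors, field",
    }
}
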
finetune/prepare_buckets_latents.py ADDED
@@ -0,0 +1,261 @@
1
+ import argparse
2
+ import os
3
+ import json
4
+
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ import torch
10
+ from torchvision import transforms
11
+
12
+ import library.model_util as model_util
13
+ import library.train_util as train_util
14
+
15
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ IMAGE_TRANSFORMS = transforms.Compose(
18
+ [
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5]),
21
+ ]
22
+ )
23
+
24
+
25
+ def collate_fn_remove_corrupted(batch):
26
+ """Collate function that allows to remove corrupted examples in the
27
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
28
+ The 'None's in the batch are removed.
29
+ """
30
+ # Filter out all the Nones (corrupted examples)
31
+ batch = list(filter(lambda x: x is not None, batch))
32
+ return batch
33
+
34
+
35
+ def get_latents(vae, images, weight_dtype):
36
+ img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
37
+ img_tensors = torch.stack(img_tensors)
38
+ img_tensors = img_tensors.to(DEVICE, weight_dtype)
39
+ with torch.no_grad():
40
+ latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
41
+ return latents
42
+
43
+
44
+ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
45
+ if is_full_path:
46
+ base_name = os.path.splitext(os.path.basename(image_key))[0]
47
+ else:
48
+ base_name = image_key
49
+ if flip:
50
+ base_name += '_flip'
51
+ return os.path.join(data_dir, base_name)
52
+
53
+
54
+ def main(args):
55
+ # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
56
+ if args.bucket_reso_steps % 8 > 0:
57
+ print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
58
+
59
+ image_paths = train_util.glob_images(args.train_data_dir)
60
+ print(f"found {len(image_paths)} images.")
61
+
62
+ if os.path.exists(args.in_json):
63
+ print(f"loading existing metadata: {args.in_json}")
64
+ with open(args.in_json, "rt", encoding='utf-8') as f:
65
+ metadata = json.load(f)
66
+ else:
67
+ print(f"no metadata / メタデータファイルがありません: {args.in_json}")
68
+ return
69
+
70
+ weight_dtype = torch.float32
71
+ if args.mixed_precision == "fp16":
72
+ weight_dtype = torch.float16
73
+ elif args.mixed_precision == "bf16":
74
+ weight_dtype = torch.bfloat16
75
+
76
+ vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
77
+ vae.eval()
78
+ vae.to(DEVICE, dtype=weight_dtype)
79
+
80
+ # bucketのサイズを計算する
81
+ max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
82
+ assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
83
+
84
+ bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
85
+ args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
86
+ if not args.bucket_no_upscale:
87
+ bucket_manager.make_buckets()
88
+ else:
89
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
90
+
91
+ # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
92
+ img_ar_errors = []
93
+
94
+ def process_batch(is_last):
95
+ for bucket in bucket_manager.buckets:
96
+ if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
97
+ latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
98
+ assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
99
+ f"latent shape {latents.shape}, {bucket[0][1].shape}"
100
+
101
+ for (image_key, _), latent in zip(bucket, latents):
102
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
103
+ np.savez(npz_file_name, latent)
104
+
105
+ # flip
106
+ if args.flip_aug:
107
+ latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
108
+
109
+ for (image_key, _), latent in zip(bucket, latents):
110
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
111
+ np.savez(npz_file_name, latent)
112
+ else:
113
+ # remove existing flipped npz
114
+ for image_key, _ in bucket:
115
+ npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
116
+ if os.path.isfile(npz_file_name):
117
+ print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
118
+ os.remove(npz_file_name)
119
+
120
+ bucket.clear()
121
+
122
+ # 読み込みの高速化のためにDataLoaderを使うオプション
123
+ if args.max_data_loader_n_workers is not None:
124
+ dataset = train_util.ImageLoadingDataset(image_paths)
125
+ data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
126
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
127
+ else:
128
+ data = [[(None, ip)] for ip in image_paths]
129
+
130
+ bucket_counts = {}
131
+ for data_entry in tqdm(data, smoothing=0.0):
132
+ if data_entry[0] is None:
133
+ continue
134
+
135
+ img_tensor, image_path = data_entry[0]
136
+ if img_tensor is not None:
137
+ image = transforms.functional.to_pil_image(img_tensor)
138
+ else:
139
+ try:
140
+ image = Image.open(image_path)
141
+ if image.mode != 'RGB':
142
+ image = image.convert("RGB")
143
+ except Exception as e:
144
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
145
+ continue
146
+
147
+ image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
148
+ if image_key not in metadata:
149
+ metadata[image_key] = {}
150
+
151
+ # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
152
+
153
+ reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
154
+ img_ar_errors.append(abs(ar_error))
155
+ bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
156
+
157
+ # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
158
+ metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
159
+
160
+ if not args.bucket_no_upscale:
161
+ # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
162
+ assert resized_size[0] == reso[0] or resized_size[1] == reso[
163
+ 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
164
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
165
+ 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
166
+
167
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
168
+ 1], f"internal error resized size is small: {resized_size}, {reso}"
169
+
170
+ # 既に存在するファイルがあればshapeを確認して同じならskipする
171
+ if args.skip_existing:
172
+ npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
173
+ if args.flip_aug:
174
+ npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
175
+
176
+ found = True
177
+ for npz_file in npz_files:
178
+ if not os.path.exists(npz_file):
179
+ found = False
180
+ break
181
+
182
+ dat = np.load(npz_file)['arr_0']
183
+ if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
184
+ found = False
185
+ break
186
+ if found:
187
+ continue
188
+
189
+ # 画像をリサイズしてトリミングする
190
+ # PILにinter_areaがないのでcv2で……
191
+ image = np.array(image)
192
+ if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
193
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
194
+
195
+ if resized_size[0] > reso[0]:
196
+ trim_size = resized_size[0] - reso[0]
197
+ image = image[:, trim_size//2:trim_size//2 + reso[0]]
198
+
199
+ if resized_size[1] > reso[1]:
200
+ trim_size = resized_size[1] - reso[1]
201
+ image = image[trim_size//2:trim_size//2 + reso[1]]
202
+
203
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
204
+
205
+ # # debug
206
+ # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
207
+
208
+ # バッチへ追加
209
+ bucket_manager.add_image(reso, (image_key, image))
210
+
211
+ # バッチを推論するか判定して推論する
212
+ process_batch(False)
213
+
214
+ # 残りを処理する
215
+ process_batch(True)
216
+
217
+ bucket_manager.sort()
218
+ for i, reso in enumerate(bucket_manager.resos):
219
+ count = bucket_counts.get(reso, 0)
220
+ if count > 0:
221
+ print(f"bucket {i} {reso}: {count}")
222
+ img_ar_errors = np.array(img_ar_errors)
223
+ print(f"mean ar error: {np.mean(img_ar_errors)}")
224
+
225
+ # metadataを書き出して終わり
226
+ print(f"writing metadata: {args.out_json}")
227
+ with open(args.out_json, "wt", encoding='utf-8') as f:
228
+ json.dump(metadata, f, indent=2)
229
+ print("done!")
230
+
231
+
232
+ if __name__ == '__main__':
233
+ parser = argparse.ArgumentParser()
234
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
235
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
236
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
237
+ parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
238
+ parser.add_argument("--v2", action='store_true',
239
+ help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
240
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
241
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
242
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
243
+ parser.add_argument("--max_resolution", type=str, default="512,512",
244
+ help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
245
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
246
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
247
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
248
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
249
+ parser.add_argument("--bucket_no_upscale", action="store_true",
250
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
251
+ parser.add_argument("--mixed_precision", type=str, default="no",
252
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
253
+ parser.add_argument("--full_path", action="store_true",
254
+ help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
255
+ parser.add_argument("--flip_aug", action="store_true",
256
+ help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
257
+ parser.add_argument("--skip_existing", action="store_true",
258
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
259
+
260
+ args = parser.parse_args()
261
+ main(args)
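
The finetune scripts chain into the metadata/latents pipeline implied by the argument parsers above; a hedged end-to-end sketch, with placeholder directory, file and model names.

import subprocess

steps = [
    ["python", "finetune/merge_captions_to_metadata.py", "train_images", "meta_cap.json"],
    ["python", "finetune/merge_dd_tags_to_metadata.py", "train_images", "meta_cap_dd.json",
     "--in_json", "meta_cap.json"],
    ["python", "finetune/clean_captions_and_tags.py", "meta_cap_dd.json", "meta_clean.json"],
    ["python", "finetune/prepare_buckets_latents.py", "train_images", "meta_clean.json",
     "meta_lat.json", "model.ckpt", "--batch_size", "4", "--max_resolution", "512,512"],
]
for cmd in steps:
    subprocess.run(cmd, check=True)  # each step reads the previous step's metadata
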
finetune/tag_images_by_wd14_tagger.py ADDED
@@ -0,0 +1,200 @@
1
+ import argparse
2
+ import csv
3
+ import glob
4
+ import os
5
+
6
+ from PIL import Image
7
+ import cv2
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ from tensorflow.keras.models import load_model
11
+ from huggingface_hub import hf_hub_download
12
+ import torch
13
+
14
+ import library.train_util as train_util
15
+
16
+ # from wd14 tagger
17
+ IMAGE_SIZE = 448
18
+
19
+ # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
20
+ DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
21
+ FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
22
+ SUB_DIR = "variables"
23
+ SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
24
+ CSV_FILE = FILES[-1]
25
+
26
+
27
+ def preprocess_image(image):
28
+ image = np.array(image)
29
+ image = image[:, :, ::-1] # RGB->BGR
30
+
31
+ # pad to square
32
+ size = max(image.shape[0:2])
33
+ pad_x = size - image.shape[1]
34
+ pad_y = size - image.shape[0]
35
+ pad_l = pad_x // 2
36
+ pad_t = pad_y // 2
37
+ image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
38
+
39
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
40
+ image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
41
+
42
+ image = image.astype(np.float32)
43
+ return image
44
+
45
+
46
+ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
47
+ def __init__(self, image_paths):
48
+ self.images = image_paths
49
+
50
+ def __len__(self):
51
+ return len(self.images)
52
+
53
+ def __getitem__(self, idx):
54
+ img_path = self.images[idx]
55
+
56
+ try:
57
+ image = Image.open(img_path).convert("RGB")
58
+ image = preprocess_image(image)
59
+ tensor = torch.tensor(image)
60
+ except Exception as e:
61
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
62
+ return None
63
+
64
+ return (tensor, img_path)
65
+
66
+
67
+ def collate_fn_remove_corrupted(batch):
68
+ """Collate function that allows to remove corrupted examples in the
69
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
70
+ The 'None's in the batch are removed.
71
+ """
72
+ # Filter out all the Nones (corrupted examples)
73
+ batch = list(filter(lambda x: x is not None, batch))
74
+ return batch
75
+
76
+
77
+ def main(args):
78
+ # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
79
+ # deprecatedの警告が出るけどなくなったらその時
80
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
81
+ if not os.path.exists(args.model_dir) or args.force_download:
82
+ print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
83
+ for file in FILES:
84
+ hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
85
+ for file in SUB_DIR_FILES:
86
+ hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
87
+ args.model_dir, SUB_DIR), force_download=True, force_filename=file)
88
+ else:
89
+ print("using existing wd14 tagger model")
90
+
91
+ # 画像を読み込む
92
+ image_paths = train_util.glob_images(args.train_data_dir)
93
+ print(f"found {len(image_paths)} images.")
94
+
95
+ print("loading model and labels")
96
+ model = load_model(args.model_dir)
97
+
98
+ # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
99
+ # 依存ライブラリを増やしたくないので自力で読むよ
100
+ with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
101
+ reader = csv.reader(f)
102
+ l = [row for row in reader]
103
+ header = l[0] # tag_id,name,category,count
104
+ rows = l[1:]
105
+ assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
106
+
107
+ tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
108
+
109
+ # 推論する
110
+ def run_batch(path_imgs):
111
+ imgs = np.array([im for _, im in path_imgs])
112
+
113
+ probs = model(imgs, training=False)
114
+ probs = probs.numpy()
115
+
116
+ for (image_path, _), prob in zip(path_imgs, probs):
117
+ # 最初の4つはratingなので無視する
118
+ # # First 4 labels are actually ratings: pick one with argmax
119
+ # ratings_names = label_names[:4]
120
+ # rating_index = ratings_names["probs"].argmax()
121
+ # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
122
+
123
+ # それ以降はタグなのでconfidenceがthresholdより高いものを追加する
124
+ # Everything else is tags: pick any where prediction confidence > threshold
125
+ tag_text = ""
126
+ for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
127
+ if p >= args.thresh and i < len(tags):
128
+ tag_text += ", " + tags[i]
129
+
130
+ if len(tag_text) > 0:
131
+ tag_text = tag_text[2:] # 最初の ", " を消す
132
+
133
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
134
+ f.write(tag_text + '\n')
135
+ if args.debug:
136
+ print(image_path, tag_text)
137
+
138
+ # 読み込みの高速化のためにDataLoaderを使うオプション
139
+ if args.max_data_loader_n_workers is not None:
140
+ dataset = ImageLoadingPrepDataset(image_paths)
141
+ data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
142
+ num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
143
+ else:
144
+ data = [[(None, ip)] for ip in image_paths]
145
+
146
+ b_imgs = []
147
+ for data_entry in tqdm(data, smoothing=0.0):
148
+ for data in data_entry:
149
+ if data is None:
150
+ continue
151
+
152
+ image, image_path = data
153
+ if image is not None:
154
+ image = image.detach().numpy()
155
+ else:
156
+ try:
157
+ image = Image.open(image_path)
158
+ if image.mode != 'RGB':
159
+ image = image.convert("RGB")
160
+ image = preprocess_image(image)
161
+ except Exception as e:
162
+ print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
163
+ continue
164
+ b_imgs.append((image_path, image))
165
+
166
+ if len(b_imgs) >= args.batch_size:
167
+ run_batch(b_imgs)
168
+ b_imgs.clear()
169
+
170
+ if len(b_imgs) > 0:
171
+ run_batch(b_imgs)
172
+
173
+ print("done!")
174
+
175
+
176
+ if __name__ == '__main__':
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
179
+ parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
180
+ help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
181
+ parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
182
+ help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
183
+ parser.add_argument("--force_download", action='store_true',
184
+ help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
185
+ parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
186
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
187
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
188
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
189
+ parser.add_argument("--caption_extention", type=str, default=None,
190
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
191
+ parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
192
+ parser.add_argument("--debug", action="store_true", help="debug mode")
193
+
194
+ args = parser.parse_args()
195
+
196
+ # スペルミスしていたオプションを復元する
197
+ if args.caption_extention is not None:
198
+ args.caption_extension = args.caption_extention
199
+
200
+ main(args)
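
A tiny check of preprocess_image() above, assuming the file can be imported: it pads the image to a white square, resizes to 448x448, and returns a float32 BGR array.

from PIL import Image
from finetune.tag_images_by_wd14_tagger import preprocess_image  # import path is an assumption

img = Image.new("RGB", (640, 360), "gray")  # arbitrary test image
arr = preprocess_image(img)
print(arr.shape, arr.dtype)  # expected: (448, 448, 3) float32
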
gen_img_diffusers.py CHANGED
@@ -47,7 +47,7 @@ VGG(
47
  """
48
 
49
  import json
50
- from typing import List, Optional, Union
51
  import glob
52
  import importlib
53
  import inspect
@@ -60,7 +60,6 @@ import math
60
  import os
61
  import random
62
  import re
63
- from typing import Any, Callable, List, Optional, Union
64
 
65
  import diffusers
66
  import numpy as np
@@ -81,6 +80,9 @@ from PIL import Image
81
  from PIL.PngImagePlugin import PngInfo
82
 
83
  import library.model_util as model_util
 
 
 
84
 
85
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
86
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +489,9 @@ class PipelineLike():
487
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
488
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
489
 
 
 
 
490
  # Textual Inversion
491
  def add_token_replacement(self, target_token_id, rep_token_ids):
492
  self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +505,11 @@ class PipelineLike():
500
  new_tokens.append(token)
501
  return new_tokens
502
 
 
 
 
503
  # region xformersとか使う部分:独自に書き換えるので関係なし
 
504
  def enable_xformers_memory_efficient_attention(self):
505
  r"""
506
  Enable memory efficient attention as implemented in xformers.
@@ -581,6 +590,8 @@ class PipelineLike():
581
  latents: Optional[torch.FloatTensor] = None,
582
  max_embeddings_multiples: Optional[int] = 3,
583
  output_type: Optional[str] = "pil",
 
 
584
  # return_dict: bool = True,
585
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
586
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -672,6 +683,9 @@ class PipelineLike():
672
  else:
673
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
674
 
 
 
 
675
  if strength < 0 or strength > 1:
676
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
677
 
@@ -752,7 +766,7 @@ class PipelineLike():
752
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
753
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
754
 
755
- if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
756
  if isinstance(clip_guide_images, PIL.Image.Image):
757
  clip_guide_images = [clip_guide_images]
758
 
@@ -765,7 +779,7 @@ class PipelineLike():
765
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
766
  if len(image_embeddings_clip) == 1:
767
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
768
- else:
769
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
770
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
771
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -774,6 +788,10 @@ class PipelineLike():
774
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
775
  if len(image_embeddings_vgg16) == 1:
776
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
 
 
 
 
777
 
778
  # set timesteps
779
  self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -781,7 +799,6 @@ class PipelineLike():
781
  latents_dtype = text_embeddings.dtype
782
  init_latents_orig = None
783
  mask = None
784
- noise = None
785
 
786
  if init_image is None:
787
  # get the initial random noise unless the user supplied it
@@ -813,6 +830,8 @@ class PipelineLike():
813
  if isinstance(init_image[0], PIL.Image.Image):
814
  init_image = [preprocess_image(im) for im in init_image]
815
  init_image = torch.cat(init_image)
 
 
816
 
817
  # mask image to tensor
818
  if mask_image is not None:
@@ -823,9 +842,24 @@ class PipelineLike():
823
 
824
  # encode the init image into latents and scale the latents
825
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
826
- init_latent_dist = self.vae.encode(init_image).latent_dist
827
- init_latents = init_latent_dist.sample(generator=generator)
828
- init_latents = 0.18215 * init_latents
829
  if len(init_latents) == 1:
830
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
831
  init_latents_orig = init_latents
@@ -864,12 +898,21 @@ class PipelineLike():
864
  extra_step_kwargs["eta"] = eta
865
 
866
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
 
 
 
 
867
  for i, t in enumerate(tqdm(timesteps)):
868
  # expand the latents if we are doing classifier free guidance
869
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
870
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
871
  # predict the noise residual
872
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
 
 
873
 
874
  # perform guidance
875
  if do_classifier_free_guidance:
@@ -911,8 +954,19 @@ class PipelineLike():
911
  if is_cancelled_callback is not None and is_cancelled_callback():
912
  return None
913
 
 
 
 
914
  latents = 1 / 0.18215 * latents
915
- image = self.vae.decode(latents).sample
916
 
917
  image = (image / 2 + 0.5).clamp(0, 1)
918
 
@@ -1799,7 +1853,7 @@ def preprocess_mask(mask):
1799
  mask = mask.convert("L")
1800
  w, h = mask.size
1801
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1802
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
1803
  mask = np.array(mask).astype(np.float32) / 255.0
1804
  mask = np.tile(mask, (4, 1, 1))
1805
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -1817,6 +1871,35 @@ def preprocess_mask(mask):
1817
  # return text_encoder
1818
 
1819
 
 
1820
  def main(args):
1821
  if args.fp16:
1822
  dtype = torch.float16
@@ -1881,10 +1964,7 @@ def main(args):
1881
  # tokenizerを読み込む
1882
  print("loading tokenizer")
1883
  if use_stable_diffusion_format:
1884
- if args.v2:
1885
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1886
- else:
1887
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
1888
 
1889
  # schedulerを用意する
1890
  sched_init_args = {}
@@ -1995,11 +2075,13 @@ def main(args):
1995
  # networkを組み込む
1996
  if args.network_module:
1997
  networks = []
 
1998
  for i, network_module in enumerate(args.network_module):
1999
  print("import network module:", network_module)
2000
  imported_module = importlib.import_module(network_module)
2001
 
2002
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
 
2003
 
2004
  net_kwargs = {}
2005
  if args.network_args and i < len(args.network_args):
@@ -2014,7 +2096,7 @@ def main(args):
2014
  network_weight = args.network_weights[i]
2015
  print("load network weights from:", network_weight)
2016
 
2017
- if model_util.is_safetensors(network_weight):
2018
  from safetensors.torch import safe_open
2019
  with safe_open(network_weight, framework="pt") as f:
2020
  metadata = f.metadata()
@@ -2037,6 +2119,18 @@ def main(args):
2037
  else:
2038
  networks = []
2039
 
2040
  if args.opt_channels_last:
2041
  print(f"set optimizing: channels last")
2042
  text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2144,14 @@ def main(args):
2050
  if vgg16_model is not None:
2051
  vgg16_model.to(memory_format=torch.channels_last)
2052
 
 
 
 
 
2053
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2054
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2055
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
 
2056
  print("pipeline is ready.")
2057
 
2058
  if args.diffusers_xformers:
@@ -2186,9 +2285,12 @@ def main(args):
2186
 
2187
  prev_image = None # for VGG16 guided
2188
  if args.guide_image_path is not None:
2189
- print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
2190
- guide_images = load_images(args.guide_image_path)
2191
- print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
 
 
 
2192
  if len(guide_images) == 0:
2193
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2194
  guide_images = None
@@ -2219,33 +2321,46 @@ def main(args):
2219
  iter_seed = random.randint(0, 0x7fffffff)
2220
 
2221
  # バッチ処理の関数
2222
- def process_batch(batch, highres_fix, highres_1st=False):
2223
  batch_size = len(batch)
2224
 
2225
  # highres_fixの処理
2226
  if highres_fix and not highres_1st:
2227
- # 1st stageのバッチを作成して呼び出す
2228
- print("process 1st stage1")
2229
  batch_1st = []
2230
- for params1, (width, height, steps, scale, negative_scale, strength) in batch:
2231
- width_1st = int(width * args.highres_fix_scale + .5)
2232
- height_1st = int(height * args.highres_fix_scale + .5)
2233
  width_1st = width_1st - width_1st % 32
2234
  height_1st = height_1st - height_1st % 32
2235
- batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
 
 
 
2236
  images_1st = process_batch(batch_1st, True, True)
2237
 
2238
  # 2nd stageのバッチを作成して以下処理する
2239
- print("process 2nd stage1")
2240
  batch_2nd = []
2241
- for i, (b1, image) in enumerate(zip(batch, images_1st)):
2242
- image = image.resize((width, height), resample=PIL.Image.LANCZOS)
2243
- (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
2244
- batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
 
2245
  batch = batch_2nd
2246
 
2247
- (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
2248
- height, steps, scale, negative_scale, strength) = batch[0]
 
2249
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2250
 
2251
  prompts = []
@@ -2278,7 +2393,7 @@ def main(args):
2278
  all_images_are_same = True
2279
  all_masks_are_same = True
2280
  all_guide_images_are_same = True
2281
- for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2282
  prompts.append(prompt)
2283
  negative_prompts.append(negative_prompt)
2284
  seeds.append(seed)
@@ -2295,9 +2410,13 @@ def main(args):
2295
  all_masks_are_same = mask_images[-2] is mask_image
2296
 
2297
  if guide_image is not None:
2298
- guide_images.append(guide_image)
2299
- if i > 0 and all_guide_images_are_same:
2300
- all_guide_images_are_same = guide_images[-2] is guide_image
 
 
 
 
2301
 
2302
  # make start code
2303
  torch.manual_seed(seed)
@@ -2320,10 +2439,24 @@ def main(args):
2320
  if guide_images is not None and all_guide_images_are_same:
2321
  guide_images = guide_images[0]
2322
 
 
 
 
 
 
 
 
 
2323
  # generate
 
 
 
 
2324
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2325
- output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2326
- if highres_1st and not args.highres_fix_save_1st:
 
 
2327
  return images
2328
 
2329
  # save image
@@ -2398,6 +2531,7 @@ def main(args):
2398
  strength = 0.8 if args.strength is None else args.strength
2399
  negative_prompt = ""
2400
  clip_prompt = None
 
2401
 
2402
  prompt_args = prompt.strip().split(' --')
2403
  prompt = prompt_args[0]
@@ -2461,6 +2595,15 @@ def main(args):
2461
  clip_prompt = m.group(1)
2462
  print(f"clip prompt: {clip_prompt}")
2463
  continue
2464
  except ValueError as ex:
2465
  print(f"Exception in parsing / 解析エラー: {parg}")
2466
  print(ex)
@@ -2498,7 +2641,12 @@ def main(args):
2498
  mask_image = mask_images[global_step % len(mask_images)]
2499
 
2500
  if guide_images is not None:
2501
- guide_image = guide_images[global_step % len(guide_images)]
 
 
 
 
 
2502
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2503
  if prev_image is None:
2504
  print("Generate 1st image without guide image.")
@@ -2506,10 +2654,9 @@ def main(args):
2506
  print("Use previous image as guide image.")
2507
  guide_image = prev_image
2508
 
2509
- # TODO named tupleか何かにする
2510
- b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2511
- (width, height, steps, scale, negative_scale, strength))
2512
- if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
2513
  process_batch(batch_data, highres_fix)
2514
  batch_data.clear()
2515
 
@@ -2553,6 +2700,8 @@ if __name__ == '__main__':
2553
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2554
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2555
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
 
 
2556
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2557
  parser.add_argument('--sampler', type=str, default='ddim',
2558
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2564,6 +2713,8 @@ if __name__ == '__main__':
2564
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2565
  parser.add_argument("--vae", type=str, default=None,
2566
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
 
 
2567
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2568
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2569
  parser.add_argument("--seed", type=int, default=None,
@@ -2578,12 +2729,15 @@ if __name__ == '__main__':
2578
  parser.add_argument("--opt_channels_last", action='store_true',
2579
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2580
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2581
- help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
2582
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2583
- help='Hypernetwork weights to load / Hypernetworkの重み')
2584
- parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
 
2585
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2586
  help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
 
 
2587
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2588
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2589
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,15 +2751,26 @@ if __name__ == '__main__':
2597
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2598
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2599
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2600
- parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
 
2601
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2602
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2603
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2604
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2605
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2606
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
 
 
2607
  parser.add_argument("--negative_scale", type=float, default=None,
2608
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2609
2610
  args = parser.parse_args()
2611
  main(args)
 
47
  """
48
 
49
  import json
50
+ from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
51
  import glob
52
  import importlib
53
  import inspect
 
60
  import os
61
  import random
62
  import re
 
63
 
64
  import diffusers
65
  import numpy as np
 
80
  from PIL.PngImagePlugin import PngInfo
81
 
82
  import library.model_util as model_util
83
+ import library.train_util as train_util
84
+ import tools.original_control_net as original_control_net
85
+ from tools.original_control_net import ControlNetInfo
86
 
87
  # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
88
  TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 
489
  self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
490
  self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
491
 
492
+ # ControlNet
493
+ self.control_nets: List[ControlNetInfo] = []
494
+
495
  # Textual Inversion
496
  def add_token_replacement(self, target_token_id, rep_token_ids):
497
  self.token_replacements[target_token_id] = rep_token_ids
 
505
  new_tokens.append(token)
506
  return new_tokens
507
 
508
+ def set_control_nets(self, ctrl_nets):
509
+ self.control_nets = ctrl_nets
510
+
511
  # region xformersとか使う部分:独自に書き換えるので関係なし
512
+
513
  def enable_xformers_memory_efficient_attention(self):
514
  r"""
515
  Enable memory efficient attention as implemented in xformers.
 
590
  latents: Optional[torch.FloatTensor] = None,
591
  max_embeddings_multiples: Optional[int] = 3,
592
  output_type: Optional[str] = "pil",
593
+ vae_batch_size: float = None,
594
+ return_latents: bool = False,
595
  # return_dict: bool = True,
596
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
 
683
  else:
684
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
685
 
686
+ vae_batch_size = batch_size if vae_batch_size is None else (
687
+ int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
688
+
689
  if strength < 0 or strength > 1:
690
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
691
 
 
766
  text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
767
  text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
768
 
769
+ if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
770
  if isinstance(clip_guide_images, PIL.Image.Image):
771
  clip_guide_images = [clip_guide_images]
772
 
 
779
  image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
780
  if len(image_embeddings_clip) == 1:
781
  image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
782
+ elif self.vgg16_guidance_scale > 0:
783
  size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
784
  clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
785
  clip_guide_images = torch.cat(clip_guide_images, dim=0)
 
788
  image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
789
  if len(image_embeddings_vgg16) == 1:
790
  image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
791
+ else:
792
+ # ControlNetのhintにguide imageを流用する
793
+ # 前処理はControlNet側で行う
794
+ pass
795
 
796
  # set timesteps
797
  self.scheduler.set_timesteps(num_inference_steps, self.device)
 
799
  latents_dtype = text_embeddings.dtype
800
  init_latents_orig = None
801
  mask = None
 
802
 
803
  if init_image is None:
804
  # get the initial random noise unless the user supplied it
 
830
  if isinstance(init_image[0], PIL.Image.Image):
831
  init_image = [preprocess_image(im) for im in init_image]
832
  init_image = torch.cat(init_image)
833
+ if isinstance(init_image, list):
834
+ init_image = torch.stack(init_image)
835
 
836
  # mask image to tensor
837
  if mask_image is not None:
 
842
 
843
  # encode the init image into latents and scale the latents
844
  init_image = init_image.to(device=self.device, dtype=latents_dtype)
845
+ if init_image.size()[2:] == (height // 8, width // 8):
846
+ init_latents = init_image
847
+ else:
848
+ if vae_batch_size >= batch_size:
849
+ init_latent_dist = self.vae.encode(init_image).latent_dist
850
+ init_latents = init_latent_dist.sample(generator=generator)
851
+ else:
852
+ if torch.cuda.is_available():
853
+ torch.cuda.empty_cache()
854
+ init_latents = []
855
+ for i in tqdm(range(0, batch_size, vae_batch_size)):
856
+ init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
857
+ if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
858
+ init_latents.append(init_latent_dist.sample(generator=generator))
859
+ init_latents = torch.cat(init_latents)
860
+
861
+ init_latents = 0.18215 * init_latents
862
+
863
  if len(init_latents) == 1:
864
  init_latents = init_latents.repeat((batch_size, 1, 1, 1))
865
  init_latents_orig = init_latents
 
898
  extra_step_kwargs["eta"] = eta
899
 
900
  num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
901
+
902
+ if self.control_nets:
903
+ guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
904
+
905
  for i, t in enumerate(tqdm(timesteps)):
906
  # expand the latents if we are doing classifier free guidance
907
  latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
908
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
909
+
910
  # predict the noise residual
911
+ if self.control_nets:
912
+ noise_pred = original_control_net.call_unet_and_control_net(
913
+ i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
914
+ else:
915
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
916
 
917
  # perform guidance
918
  if do_classifier_free_guidance:
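For readers skimming the denoising loop: the ControlNet branch above only swaps the noise-prediction call; everything else in the loop is unchanged. A condensed sketch, assuming the tools.original_control_net helpers uploaded in this commit and the objects the pipeline already holds (the wrapper function itself is ours):

import tools.original_control_net as original_control_net

def predict_noise(step, timesteps, t, unet, control_nets, guided_hints,
                  latent_model_input, text_embeddings, num_latent_input):
    if control_nets:
        # guided_hints come from get_guided_hints(...) once per generation
        return original_control_net.call_unet_and_control_net(
            step, num_latent_input, unet, control_nets, guided_hints,
            step / len(timesteps), latent_model_input, t, text_embeddings).sample
    # plain Stable Diffusion path, unchanged
    return unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample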
 
954
  if is_cancelled_callback is not None and is_cancelled_callback():
955
  return None
956
 
957
+ if return_latents:
958
+ return (latents, False)
959
+
960
  latents = 1 / 0.18215 * latents
961
+ if vae_batch_size >= batch_size:
962
+ image = self.vae.decode(latents).sample
963
+ else:
964
+ if torch.cuda.is_available():
965
+ torch.cuda.empty_cache()
966
+ images = []
967
+ for i in tqdm(range(0, batch_size, vae_batch_size)):
968
+ images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
969
+ image = torch.cat(images)
970
 
971
  image = (image / 2 + 0.5).clamp(0, 1)
972
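The chunked VAE decode above bounds peak memory when --vae_batch_size is smaller than the generation batch. The same idea as a standalone sketch (vae and latents stand in for the objects used by the pipeline):

import torch

def decode_in_chunks(vae, latents: torch.Tensor, vae_batch_size: int) -> torch.Tensor:
    latents = latents / 0.18215  # undo the SD latent scaling factor
    if vae_batch_size >= latents.shape[0]:
        return vae.decode(latents).sample
    images = []
    for i in range(0, latents.shape[0], vae_batch_size):
        chunk = latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)
        images.append(vae.decode(chunk).sample)
    return torch.cat(images)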
 
 
1853
  mask = mask.convert("L")
1854
  w, h = mask.size
1855
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1856
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
1857
  mask = np.array(mask).astype(np.float32) / 255.0
1858
  mask = np.tile(mask, (4, 1, 1))
1859
  mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
 
1871
  # return text_encoder
1872
 
1873
 
1874
+ class BatchDataBase(NamedTuple):
1875
+ # バッチ分割が必要ないデータ
1876
+ step: int
1877
+ prompt: str
1878
+ negative_prompt: str
1879
+ seed: int
1880
+ init_image: Any
1881
+ mask_image: Any
1882
+ clip_prompt: str
1883
+ guide_image: Any
1884
+
1885
+
1886
+ class BatchDataExt(NamedTuple):
1887
+ # バッチ分割が必要なデータ
1888
+ width: int
1889
+ height: int
1890
+ steps: int
1891
+ scale: float
1892
+ negative_scale: float
1893
+ strength: float
1894
+ network_muls: Tuple[float]
1895
+
1896
+
1897
+ class BatchData(NamedTuple):
1898
+ return_latents: bool
1899
+ base: BatchDataBase
1900
+ ext: BatchDataExt
1901
+
1902
+
1903
  def main(args):
1904
  if args.fp16:
1905
  dtype = torch.float16
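The named tuples above replace the anonymous nested tuple flagged as a TODO in the removed code: BatchDataBase holds per-image fields that may differ inside one batch, while any change in BatchDataExt forces a batch split. An illustration with made-up values (only the heights differ):

b_prev = BatchData(False,
                   BatchDataBase(0, "a prompt", "", 42, None, None, None, None),
                   BatchDataExt(512, 768, 30, 7.5, None, 0.8, None))
b_next = BatchData(False,
                   BatchDataBase(1, "another prompt", "", 43, None, None, None, None),
                   BatchDataExt(512, 512, 30, 7.5, None, 0.8, None))
needs_split = b_prev.ext != b_next.ext  # True -> flush the current batch first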
 
1964
  # tokenizerを読み込む
1965
  print("loading tokenizer")
1966
  if use_stable_diffusion_format:
1967
+ tokenizer = train_util.load_tokenizer(args)
1968
 
1969
  # schedulerを用意する
1970
  sched_init_args = {}
 
2075
  # networkを組み込む
2076
  if args.network_module:
2077
  networks = []
2078
+ network_default_muls = []
2079
  for i, network_module in enumerate(args.network_module):
2080
  print("import network module:", network_module)
2081
  imported_module = importlib.import_module(network_module)
2082
 
2083
  network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
2084
+ network_default_muls.append(network_mul)
2085
 
2086
  net_kwargs = {}
2087
  if args.network_args and i < len(args.network_args):
 
2096
  network_weight = args.network_weights[i]
2097
  print("load network weights from:", network_weight)
2098
 
2099
+ if model_util.is_safetensors(network_weight) and args.network_show_meta:
2100
  from safetensors.torch import safe_open
2101
  with safe_open(network_weight, framework="pt") as f:
2102
  metadata = f.metadata()
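The new --network_show_meta flag only reads the safetensors header. A minimal sketch of the same lookup ("my_network.safetensors" is a placeholder path):

from safetensors.torch import safe_open

def show_network_metadata(path: str) -> None:
    with safe_open(path, framework="pt") as f:
        metadata = f.metadata()  # dict of stored metadata, or None if absent
    print(metadata)

# show_network_metadata("my_network.safetensors")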
 
2119
  else:
2120
  networks = []
2121
 
2122
+ # ControlNetの処理
2123
+ control_nets: List[ControlNetInfo] = []
2124
+ if args.control_net_models:
2125
+ for i, model in enumerate(args.control_net_models):
2126
+ prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
2127
+ weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
2128
+ ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
2129
+
2130
+ ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
2131
+ prep = original_control_net.load_preprocess(prep_type)
2132
+ control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
2133
+
2134
  if args.opt_channels_last:
2135
  print(f"set optimizing: channels last")
2136
  text_encoder.to(memory_format=torch.channels_last)
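The ControlNet loading loop above pads the optional --control_net_preps / --control_net_weights / --control_net_ratios lists: whenever a list is missing or shorter than --control_net_models, the remaining models fall back to None / 1.0 / 1.0. A small sketch of that fallback rule (model names are placeholders):

def pick(values, i, default):
    return default if not values or len(values) <= i else values[i]

models = ["control_canny.safetensors", "control_depth.safetensors"]  # placeholders
weights = [0.8]  # only the first model gets an explicit weight
settings = [(m, pick(None, i, None), pick(weights, i, 1.0), pick(None, i, 1.0))
            for i, m in enumerate(models)]
# -> [(canny, None, 0.8, 1.0), (depth, None, 1.0, 1.0)]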
 
2144
  if vgg16_model is not None:
2145
  vgg16_model.to(memory_format=torch.channels_last)
2146
 
2147
+ for cn in control_nets:
2148
+ cn.unet.to(memory_format=torch.channels_last)
2149
+ cn.net.to(memory_format=torch.channels_last)
2150
+
2151
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
2152
  clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
2153
  vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
2154
+ pipe.set_control_nets(control_nets)
2155
  print("pipeline is ready.")
2156
 
2157
  if args.diffusers_xformers:
 
2285
 
2286
  prev_image = None # for VGG16 guided
2287
  if args.guide_image_path is not None:
2288
+ print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
2289
+ guide_images = []
2290
+ for p in args.guide_image_path:
2291
+ guide_images.extend(load_images(p))
2292
+
2293
+ print(f"loaded {len(guide_images)} guide images for guidance")
2294
  if len(guide_images) == 0:
2295
  print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
2296
  guide_images = None
 
2321
  iter_seed = random.randint(0, 0x7fffffff)
2322
 
2323
  # バッチ処理の関数
2324
+ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
2325
  batch_size = len(batch)
2326
 
2327
  # highres_fixの処理
2328
  if highres_fix and not highres_1st:
2329
+ # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
2330
+ print("process 1st stage")
2331
  batch_1st = []
2332
+ for _, base, ext in batch:
2333
+ width_1st = int(ext.width * args.highres_fix_scale + .5)
2334
+ height_1st = int(ext.height * args.highres_fix_scale + .5)
2335
  width_1st = width_1st - width_1st % 32
2336
  height_1st = height_1st - height_1st % 32
2337
+
2338
+ ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
2339
+ ext.negative_scale, ext.strength, ext.network_muls)
2340
+ batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
2341
  images_1st = process_batch(batch_1st, True, True)
2342
 
2343
  # 2nd stageのバッチを作成して以下処理する
2344
+ print("process 2nd stage")
2345
+ if args.highres_fix_latents_upscaling:
2346
+ org_dtype = images_1st.dtype
2347
+ if images_1st.dtype == torch.bfloat16:
2348
+ images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
2349
+ images_1st = torch.nn.functional.interpolate(
2350
+ images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
2351
+ images_1st = images_1st.to(org_dtype)
2352
+
2353
  batch_2nd = []
2354
+ for i, (bd, image) in enumerate(zip(batch, images_1st)):
2355
+ if not args.highres_fix_latents_upscaling:
2356
+ image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
2357
+ bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
2358
+ batch_2nd.append(bd_2nd)
2359
  batch = batch_2nd
2360
 
2361
+ # このバッチの情報を取り出す
2362
+ return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
2363
+ (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
2364
  noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
2365
 
2366
  prompts = []
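With --highres_fix_latents_upscaling the 1st stage returns latents instead of PIL images, and they are resized directly in latent space before the 2nd stage. A sketch of that hand-off; the bf16 round-trip mirrors the interpolate limitation noted in the diff above:

import torch

def upscale_latents(latents_1st: torch.Tensor, width: int, height: int) -> torch.Tensor:
    org_dtype = latents_1st.dtype
    if org_dtype == torch.bfloat16:
        latents_1st = latents_1st.float()  # interpolate does not accept bf16
    latents_1st = torch.nn.functional.interpolate(
        latents_1st, (height // 8, width // 8), mode="bilinear")
    return latents_1st.to(org_dtype)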
 
2393
  all_images_are_same = True
2394
  all_masks_are_same = True
2395
  all_guide_images_are_same = True
2396
+ for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
2397
  prompts.append(prompt)
2398
  negative_prompts.append(negative_prompt)
2399
  seeds.append(seed)
 
2410
  all_masks_are_same = mask_images[-2] is mask_image
2411
 
2412
  if guide_image is not None:
2413
+ if type(guide_image) is list:
2414
+ guide_images.extend(guide_image)
2415
+ all_guide_images_are_same = False
2416
+ else:
2417
+ guide_images.append(guide_image)
2418
+ if i > 0 and all_guide_images_are_same:
2419
+ all_guide_images_are_same = guide_images[-2] is guide_image
2420
 
2421
  # make start code
2422
  torch.manual_seed(seed)
 
2439
  if guide_images is not None and all_guide_images_are_same:
2440
  guide_images = guide_images[0]
2441
 
2442
+ # ControlNet使用時はguide imageをリサイズする
2443
+ if control_nets:
2444
+ # TODO resampleのメソッド
2445
+ guide_images = guide_images if type(guide_images) == list else [guide_images]
2446
+ guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
2447
+ if len(guide_images) == 1:
2448
+ guide_images = guide_images[0]
2449
+
2450
  # generate
2451
+ if networks:
2452
+ for n, m in zip(networks, network_muls if network_muls else network_default_muls):
2453
+ n.set_multiplier(m)
2454
+
2455
  images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
2456
+ output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
2457
+ vae_batch_size=args.vae_batch_size, return_latents=return_latents,
2458
+ clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
2459
+ if highres_1st and not args.highres_fix_save_1st: # return images or latents
2460
  return images
2461
 
2462
  # save image
 
2531
  strength = 0.8 if args.strength is None else args.strength
2532
  negative_prompt = ""
2533
  clip_prompt = None
2534
+ network_muls = None
2535
 
2536
  prompt_args = prompt.strip().split(' --')
2537
  prompt = prompt_args[0]
 
2595
  clip_prompt = m.group(1)
2596
  print(f"clip prompt: {clip_prompt}")
2597
  continue
2598
+
2599
+ m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
2600
+ if m: # network multiplies
2601
+ network_muls = [float(v) for v in m.group(1).split(",")]
2602
+ while len(network_muls) < len(networks):
2603
+ network_muls.append(network_muls[-1])
2604
+ print(f"network mul: {network_muls}")
2605
+ continue
2606
+
2607
  except ValueError as ex:
2608
  print(f"Exception in parsing / 解析エラー: {parg}")
2609
  print(ex)
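The new per-prompt "--am" option takes a comma-separated list of multipliers and pads it with its last value so every loaded network gets one. A standalone sketch of the parsing (the helper name is ours):

import re

def parse_network_muls(parg: str, num_networks: int):
    m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
    if m is None:
        return None
    muls = [float(v) for v in m.group(1).split(",")]
    while len(muls) < num_networks:
        muls.append(muls[-1])
    return muls

# parse_network_muls("am 0.5,0.8", 3) -> [0.5, 0.8, 0.8]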
 
2641
  mask_image = mask_images[global_step % len(mask_images)]
2642
 
2643
  if guide_images is not None:
2644
+ if control_nets: # 複数件の場合あり
2645
+ c = len(control_nets)
2646
+ p = global_step % (len(guide_images) // c)
2647
+ guide_image = guide_images[p * c:p * c + c]
2648
+ else:
2649
+ guide_image = guide_images[global_step % len(guide_images)]
2650
  elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
2651
  if prev_image is None:
2652
  print("Generate 1st image without guide image.")
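When several ControlNets are active, guide images are consumed in groups of len(control_nets), one image per ControlNet, rather than one image per generation step. The indexing used above, as a small helper (the function name is ours):

def pick_guide_image(guide_images, num_control_nets: int, global_step: int):
    if num_control_nets:
        c = num_control_nets
        p = global_step % (len(guide_images) // c)
        return guide_images[p * c:p * c + c]  # list: one image per ControlNet
    return guide_images[global_step % len(guide_images)]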
 
2654
  print("Use previous image as guide image.")
2655
  guide_image = prev_image
2656
 
2657
+ b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
2658
+ BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
2659
+ if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
 
2660
  process_batch(batch_data, highres_fix)
2661
  batch_data.clear()
2662
 
 
2700
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
2701
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
2702
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
2703
+ parser.add_argument("--vae_batch_size", type=float, default=None,
2704
+ help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
2705
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
2706
  parser.add_argument('--sampler', type=str, default='ddim',
2707
  choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
 
2713
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
2714
  parser.add_argument("--vae", type=str, default=None,
2715
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
2716
+ parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
2717
+ help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
2718
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
2719
  # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
2720
  parser.add_argument("--seed", type=int, default=None,
 
2729
  parser.add_argument("--opt_channels_last", action='store_true',
2730
  help='set channels last option to model / モデルにchannels lastを指定し最適化する')
2731
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
2732
+ help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
2733
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
2734
+ help='additional network weights to load / 追加ネットワークの重み')
2735
+ parser.add_argument("--network_mul", type=float, default=None, nargs='*',
2736
+ help='additional network multiplier / 追加ネットワークの効果の倍率')
2737
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
2738
  help='additional arguments for network (key=value) / ネットワークへの追加の引数')
2739
+ parser.add_argument("--network_show_meta", action='store_true',
2740
+ help='show metadata of network model / ネットワークモデルのメタデータを表示する')
2741
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
2742
  help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
2743
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
 
2751
  help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
2752
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
2753
  help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
2754
+ parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
2755
+ help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
2756
  parser.add_argument("--highres_fix_scale", type=float, default=None,
2757
  help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
2758
  parser.add_argument("--highres_fix_steps", type=int, default=28,
2759
  help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
2760
  parser.add_argument("--highres_fix_save_1st", action='store_true',
2761
  help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
2762
+ parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
2763
+ help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
2764
  parser.add_argument("--negative_scale", type=float, default=None,
2765
  help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
2766
 
2767
+ parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
2768
+ help='ControlNet models to use / 使用するControlNetのモデル名')
2769
+ parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
2770
+ help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
2771
+ parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
2772
+ parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
2773
+ help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
2774
+
2775
  args = parser.parse_args()
2776
  main(args)
library/config_util.py ADDED
@@ -0,0 +1,527 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ from textwrap import dedent, indent
8
+ import json
9
+ from pathlib import Path
10
+ # from toolz import curry
11
+ from typing import (
12
+ List,
13
+ Optional,
14
+ Sequence,
15
+ Tuple,
16
+ Union,
17
+ )
18
+
19
+ import toml
20
+ import voluptuous
21
+ from voluptuous import (
22
+ Any,
23
+ ExactSequence,
24
+ MultipleInvalid,
25
+ Object,
26
+ Required,
27
+ Schema,
28
+ )
29
+ from transformers import CLIPTokenizer
30
+
31
+ from . import train_util
32
+ from .train_util import (
33
+ DreamBoothSubset,
34
+ FineTuningSubset,
35
+ DreamBoothDataset,
36
+ FineTuningDataset,
37
+ DatasetGroup,
38
+ )
39
+
40
+
41
+ def add_config_arguments(parser: argparse.ArgumentParser):
42
+ parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
43
+
44
+ # TODO: inherit Params class in Subset, Dataset
45
+
46
+ @dataclass
47
+ class BaseSubsetParams:
48
+ image_dir: Optional[str] = None
49
+ num_repeats: int = 1
50
+ shuffle_caption: bool = False
51
+ keep_tokens: int = 0
52
+ color_aug: bool = False
53
+ flip_aug: bool = False
54
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
55
+ random_crop: bool = False
56
+ caption_dropout_rate: float = 0.0
57
+ caption_dropout_every_n_epochs: int = 0
58
+ caption_tag_dropout_rate: float = 0.0
59
+
60
+ @dataclass
61
+ class DreamBoothSubsetParams(BaseSubsetParams):
62
+ is_reg: bool = False
63
+ class_tokens: Optional[str] = None
64
+ caption_extension: str = ".caption"
65
+
66
+ @dataclass
67
+ class FineTuningSubsetParams(BaseSubsetParams):
68
+ metadata_file: Optional[str] = None
69
+
70
+ @dataclass
71
+ class BaseDatasetParams:
72
+ tokenizer: CLIPTokenizer = None
73
+ max_token_length: int = None
74
+ resolution: Optional[Tuple[int, int]] = None
75
+ debug_dataset: bool = False
76
+
77
+ @dataclass
78
+ class DreamBoothDatasetParams(BaseDatasetParams):
79
+ batch_size: int = 1
80
+ enable_bucket: bool = False
81
+ min_bucket_reso: int = 256
82
+ max_bucket_reso: int = 1024
83
+ bucket_reso_steps: int = 64
84
+ bucket_no_upscale: bool = False
85
+ prior_loss_weight: float = 1.0
86
+
87
+ @dataclass
88
+ class FineTuningDatasetParams(BaseDatasetParams):
89
+ batch_size: int = 1
90
+ enable_bucket: bool = False
91
+ min_bucket_reso: int = 256
92
+ max_bucket_reso: int = 1024
93
+ bucket_reso_steps: int = 64
94
+ bucket_no_upscale: bool = False
95
+
96
+ @dataclass
97
+ class SubsetBlueprint:
98
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
99
+
100
+ @dataclass
101
+ class DatasetBlueprint:
102
+ is_dreambooth: bool
103
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
104
+ subsets: Sequence[SubsetBlueprint]
105
+
106
+ @dataclass
107
+ class DatasetGroupBlueprint:
108
+ datasets: Sequence[DatasetBlueprint]
109
+ @dataclass
110
+ class Blueprint:
111
+ dataset_group: DatasetGroupBlueprint
112
+
113
+
114
+ class ConfigSanitizer:
115
+ # @curry
116
+ @staticmethod
117
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
118
+ Schema(ExactSequence([klass, klass]))(value)
119
+ return tuple(value)
120
+
121
+ # @curry
122
+ @staticmethod
123
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
124
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
125
+ try:
126
+ Schema(klass)(value)
127
+ return (value, value)
128
+ except:
129
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
130
+
131
+ # subset schema
132
+ SUBSET_ASCENDABLE_SCHEMA = {
133
+ "color_aug": bool,
134
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
135
+ "flip_aug": bool,
136
+ "num_repeats": int,
137
+ "random_crop": bool,
138
+ "shuffle_caption": bool,
139
+ "keep_tokens": int,
140
+ }
141
+ # DO means DropOut
142
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
143
+ "caption_dropout_every_n_epochs": int,
144
+ "caption_dropout_rate": Any(float, int),
145
+ "caption_tag_dropout_rate": Any(float, int),
146
+ }
147
+ # DB means DreamBooth
148
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
149
+ "caption_extension": str,
150
+ "class_tokens": str,
151
+ }
152
+ DB_SUBSET_DISTINCT_SCHEMA = {
153
+ Required("image_dir"): str,
154
+ "is_reg": bool,
155
+ }
156
+ # FT means FineTuning
157
+ FT_SUBSET_DISTINCT_SCHEMA = {
158
+ Required("metadata_file"): str,
159
+ "image_dir": str,
160
+ }
161
+
162
+ # datasets schema
163
+ DATASET_ASCENDABLE_SCHEMA = {
164
+ "batch_size": int,
165
+ "bucket_no_upscale": bool,
166
+ "bucket_reso_steps": int,
167
+ "enable_bucket": bool,
168
+ "max_bucket_reso": int,
169
+ "min_bucket_reso": int,
170
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
171
+ }
172
+
173
+ # options handled by argparse but not handled by user config
174
+ ARGPARSE_SPECIFIC_SCHEMA = {
175
+ "debug_dataset": bool,
176
+ "max_token_length": Any(None, int),
177
+ "prior_loss_weight": Any(float, int),
178
+ }
179
+ # for handling default None value of argparse
180
+ ARGPARSE_NULLABLE_OPTNAMES = [
181
+ "face_crop_aug_range",
182
+ "resolution",
183
+ ]
184
+ # prepare map because option name may differ among argparse and user config
185
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
186
+ "train_batch_size": "batch_size",
187
+ "dataset_repeats": "num_repeats",
188
+ }
189
+
190
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
191
+ assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
192
+
193
+ self.db_subset_schema = self.__merge_dict(
194
+ self.SUBSET_ASCENDABLE_SCHEMA,
195
+ self.DB_SUBSET_DISTINCT_SCHEMA,
196
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
197
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
198
+ )
199
+
200
+ self.ft_subset_schema = self.__merge_dict(
201
+ self.SUBSET_ASCENDABLE_SCHEMA,
202
+ self.FT_SUBSET_DISTINCT_SCHEMA,
203
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
204
+ )
205
+
206
+ self.db_dataset_schema = self.__merge_dict(
207
+ self.DATASET_ASCENDABLE_SCHEMA,
208
+ self.SUBSET_ASCENDABLE_SCHEMA,
209
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
210
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
211
+ {"subsets": [self.db_subset_schema]},
212
+ )
213
+
214
+ self.ft_dataset_schema = self.__merge_dict(
215
+ self.DATASET_ASCENDABLE_SCHEMA,
216
+ self.SUBSET_ASCENDABLE_SCHEMA,
217
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
218
+ {"subsets": [self.ft_subset_schema]},
219
+ )
220
+
221
+ if support_dreambooth and support_finetuning:
222
+ def validate_flex_dataset(dataset_config: dict):
223
+ subsets_config = dataset_config.get("subsets", [])
224
+
225
+ # check dataset meets FT style
226
+ # NOTE: all FT subsets should have "metadata_file"
227
+ if all(["metadata_file" in subset for subset in subsets_config]):
228
+ return Schema(self.ft_dataset_schema)(dataset_config)
229
+ # check dataset meets DB style
230
+ # NOTE: all DB subsets should have no "metadata_file"
231
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
232
+ return Schema(self.db_dataset_schema)(dataset_config)
233
+ else:
234
+ raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
235
+
236
+ self.dataset_schema = validate_flex_dataset
237
+ elif support_dreambooth:
238
+ self.dataset_schema = self.db_dataset_schema
239
+ else:
240
+ self.dataset_schema = self.ft_dataset_schema
241
+
242
+ self.general_schema = self.__merge_dict(
243
+ self.DATASET_ASCENDABLE_SCHEMA,
244
+ self.SUBSET_ASCENDABLE_SCHEMA,
245
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
246
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
247
+ )
248
+
249
+ self.user_config_validator = Schema({
250
+ "general": self.general_schema,
251
+ "datasets": [self.dataset_schema],
252
+ })
253
+
254
+ self.argparse_schema = self.__merge_dict(
255
+ self.general_schema,
256
+ self.ARGPARSE_SPECIFIC_SCHEMA,
257
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
258
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
259
+ )
260
+
261
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
262
+
263
+ def sanitize_user_config(self, user_config: dict) -> dict:
264
+ try:
265
+ return self.user_config_validator(user_config)
266
+ except MultipleInvalid:
267
+ # TODO: エラー発生時のメッセージをわかりやすくする
268
+ print("Invalid user config / ユーザ設定の形式が正しくないようです")
269
+ raise
270
+
271
+ # NOTE: In nature, argument parser result is not needed to be sanitize
272
+ # However this will help us to detect program bug
273
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
274
+ try:
275
+ return self.argparse_config_validator(argparse_namespace)
276
+ except MultipleInvalid:
277
+ # XXX: this should be a bug
278
+ print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
279
+ raise
280
+
281
+ # NOTE: value would be overwritten by latter dict if there is already the same key
282
+ @staticmethod
283
+ def __merge_dict(*dict_list: dict) -> dict:
284
+ merged = {}
285
+ for schema in dict_list:
286
+ # merged |= schema
287
+ for k, v in schema.items():
288
+ merged[k] = v
289
+ return merged
290
+
291
+
292
+ class BlueprintGenerator:
293
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
294
+ }
295
+
296
+ def __init__(self, sanitizer: ConfigSanitizer):
297
+ self.sanitizer = sanitizer
298
+
299
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
300
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
301
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
302
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
303
+
304
+ # convert argparse namespace to dict like config
305
+ # NOTE: it is ok to have extra entries in dict
306
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
307
+ argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
308
+
309
+ general_config = sanitized_user_config.get("general", {})
310
+
311
+ dataset_blueprints = []
312
+ for dataset_config in sanitized_user_config.get("datasets", []):
313
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
314
+ subsets = dataset_config.get("subsets", [])
315
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
316
+ if is_dreambooth:
317
+ subset_params_klass = DreamBoothSubsetParams
318
+ dataset_params_klass = DreamBoothDatasetParams
319
+ else:
320
+ subset_params_klass = FineTuningSubsetParams
321
+ dataset_params_klass = FineTuningDatasetParams
322
+
323
+ subset_blueprints = []
324
+ for subset_config in subsets:
325
+ params = self.generate_params_by_fallbacks(subset_params_klass,
326
+ [subset_config, dataset_config, general_config, argparse_config, runtime_params])
327
+ subset_blueprints.append(SubsetBlueprint(params))
328
+
329
+ params = self.generate_params_by_fallbacks(dataset_params_klass,
330
+ [dataset_config, general_config, argparse_config, runtime_params])
331
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
332
+
333
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
334
+
335
+ return Blueprint(dataset_group_blueprint)
336
+
337
+ @staticmethod
338
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
339
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
340
+ search_value = BlueprintGenerator.search_value
341
+ default_params = asdict(param_klass())
342
+ param_names = default_params.keys()
343
+
344
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
345
+
346
+ return param_klass(**params)
347
+
348
+ @staticmethod
349
+ def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
350
+ for cand in fallbacks:
351
+ value = cand.get(key)
352
+ if value is not None:
353
+ return value
354
+
355
+ return default_value
356
+
357
+
358
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
359
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
360
+
361
+ for dataset_blueprint in dataset_group_blueprint.datasets:
362
+ if dataset_blueprint.is_dreambooth:
363
+ subset_klass = DreamBoothSubset
364
+ dataset_klass = DreamBoothDataset
365
+ else:
366
+ subset_klass = FineTuningSubset
367
+ dataset_klass = FineTuningDataset
368
+
369
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
370
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
371
+ datasets.append(dataset)
372
+
373
+ # print info
374
+ info = ""
375
+ for i, dataset in enumerate(datasets):
376
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
377
+ info += dedent(f"""\
378
+ [Dataset {i}]
379
+ batch_size: {dataset.batch_size}
380
+ resolution: {(dataset.width, dataset.height)}
381
+ enable_bucket: {dataset.enable_bucket}
382
+ """)
383
+
384
+ if dataset.enable_bucket:
385
+ info += indent(dedent(f"""\
386
+ min_bucket_reso: {dataset.min_bucket_reso}
387
+ max_bucket_reso: {dataset.max_bucket_reso}
388
+ bucket_reso_steps: {dataset.bucket_reso_steps}
389
+ bucket_no_upscale: {dataset.bucket_no_upscale}
390
+ \n"""), " ")
391
+ else:
392
+ info += "\n"
393
+
394
+ for j, subset in enumerate(dataset.subsets):
395
+ info += indent(dedent(f"""\
396
+ [Subset {j} of Dataset {i}]
397
+ image_dir: "{subset.image_dir}"
398
+ image_count: {subset.img_count}
399
+ num_repeats: {subset.num_repeats}
400
+ shuffle_caption: {subset.shuffle_caption}
401
+ keep_tokens: {subset.keep_tokens}
402
+ caption_dropout_rate: {subset.caption_dropout_rate}
403
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
404
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
405
+ color_aug: {subset.color_aug}
406
+ flip_aug: {subset.flip_aug}
407
+ face_crop_aug_range: {subset.face_crop_aug_range}
408
+ random_crop: {subset.random_crop}
409
+ """), " ")
410
+
411
+ if is_dreambooth:
412
+ info += indent(dedent(f"""\
413
+ is_reg: {subset.is_reg}
414
+ class_tokens: {subset.class_tokens}
415
+ caption_extension: {subset.caption_extension}
416
+ \n"""), " ")
417
+ else:
418
+ info += indent(dedent(f"""\
419
+ metadata_file: {subset.metadata_file}
420
+ \n"""), " ")
421
+
422
+ print(info)
423
+
424
+ # make buckets first because it determines the length of dataset
425
+ for i, dataset in enumerate(datasets):
426
+ print(f"[Dataset {i}]")
427
+ dataset.make_buckets()
428
+
429
+ return DatasetGroup(datasets)
430
+
431
+
432
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
433
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
434
+ tokens = name.split('_')
435
+ try:
436
+ n_repeats = int(tokens[0])
437
+ except ValueError as e:
438
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
439
+ return 0, ""
440
+ caption_by_folder = '_'.join(tokens[1:])
441
+ return n_repeats, caption_by_folder
442
+
443
+ def generate(base_dir: Optional[str], is_reg: bool):
444
+ if base_dir is None:
445
+ return []
446
+
447
+ base_dir: Path = Path(base_dir)
448
+ if not base_dir.is_dir():
449
+ return []
450
+
451
+ subsets_config = []
452
+ for subdir in base_dir.iterdir():
453
+ if not subdir.is_dir():
454
+ continue
455
+
456
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
457
+ if num_repeats < 1:
458
+ continue
459
+
460
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
461
+ subsets_config.append(subset_config)
462
+
463
+ return subsets_config
464
+
465
+ subsets_config = []
466
+ subsets_config += generate(train_data_dir, False)
467
+ subsets_config += generate(reg_data_dir, True)
468
+
469
+ return subsets_config
470
+
471
+
472
+ def load_user_config(file: str) -> dict:
473
+ file: Path = Path(file)
474
+ if not file.is_file():
475
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
476
+
477
+ if file.name.lower().endswith('.json'):
478
+ try:
479
+ config = json.load(file)
480
+ except Exception:
481
+ print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
482
+ raise
483
+ elif file.name.lower().endswith('.toml'):
484
+ try:
485
+ config = toml.load(file)
486
+ except Exception:
487
+ print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
488
+ raise
489
+ else:
490
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
491
+
492
+ return config
493
+
494
+
495
+ # for config test
496
+ if __name__ == "__main__":
497
+ parser = argparse.ArgumentParser()
498
+ parser.add_argument("--support_dreambooth", action="store_true")
499
+ parser.add_argument("--support_finetuning", action="store_true")
500
+ parser.add_argument("--support_dropout", action="store_true")
501
+ parser.add_argument("dataset_config")
502
+ config_args, remain = parser.parse_known_args()
503
+
504
+ parser = argparse.ArgumentParser()
505
+ train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
506
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
507
+ argparse_namespace = parser.parse_args(remain)
508
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
509
+
510
+ print("[argparse_namespace]")
511
+ print(vars(argparse_namespace))
512
+
513
+ user_config = load_user_config(config_args.dataset_config)
514
+
515
+ print("\n[user_config]")
516
+ print(user_config)
517
+
518
+ sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
519
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
520
+
521
+ print("\n[sanitized_user_config]")
522
+ print(sanitized_user_config)
523
+
524
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
525
+
526
+ print("\n[blueprint]")
527
+ print(blueprint)
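library/config_util.py validates a TOML or JSON dataset config and turns it into dataset/subset blueprints. A hypothetical DreamBooth-style config run through the sanitizer (directory paths and class tokens are placeholders; the blueprint step additionally needs the argparse namespace, as in the test block above):

import toml
from library.config_util import ConfigSanitizer

user_config = toml.loads("""
[general]
enable_bucket = true
resolution = 512

[[datasets]]
batch_size = 2

  [[datasets.subsets]]
  image_dir = "/path/to/10_mychar"
  num_repeats = 10
  class_tokens = "mychar girl"

  [[datasets.subsets]]
  image_dir = "/path/to/reg"
  is_reg = true
  class_tokens = "girl"
""")

sanitizer = ConfigSanitizer(support_dreambooth=True, support_finetuning=False, support_dropout=True)
print(sanitizer.sanitize_user_config(user_config))  # validated dict; resolution becomes (512, 512)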
library/train_util.py CHANGED
@@ -1,12 +1,21 @@
1
  # common functions for training
2
 
3
  import argparse
 
4
  import json
 
5
  import shutil
6
  import time
7
- from typing import Dict, List, NamedTuple, Tuple
8
  from accelerate import Accelerator
9
- from torch.autograd.function import Function
10
  import glob
11
  import math
12
  import os
@@ -17,10 +26,16 @@ from io import BytesIO
17
 
18
  from tqdm import tqdm
19
  import torch
 
20
  from torchvision import transforms
21
  from transformers import CLIPTokenizer
 
22
  import diffusers
23
- from diffusers import DDPMScheduler, StableDiffusionPipeline
24
  import albumentations as albu
25
  import numpy as np
26
  from PIL import Image
@@ -195,23 +210,93 @@ class BucketBatchIndex(NamedTuple):
195
  batch_index: int
196
 
197
198
  class BaseDataset(torch.utils.data.Dataset):
199
- def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
200
  super().__init__()
201
- self.tokenizer: CLIPTokenizer = tokenizer
202
  self.max_token_length = max_token_length
203
- self.shuffle_caption = shuffle_caption
204
- self.shuffle_keep_tokens = shuffle_keep_tokens
205
  # width/height is used when enable_bucket==False
206
  self.width, self.height = (None, None) if resolution is None else resolution
207
- self.face_crop_aug_range = face_crop_aug_range
208
- self.flip_aug = flip_aug
209
- self.color_aug = color_aug
210
  self.debug_dataset = debug_dataset
211
- self.random_crop = random_crop
 
 
212
  self.token_padding_disabled = False
213
- self.dataset_dirs_info = {}
214
- self.reg_dataset_dirs_info = {}
215
  self.tag_frequency = {}
216
 
217
  self.enable_bucket = False
@@ -225,49 +310,28 @@ class BaseDataset(torch.utils.data.Dataset):
225
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
226
 
227
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
228
- self.dropout_rate: float = 0
229
- self.dropout_every_n_epochs: int = None
230
- self.tag_dropout_rate: float = 0
231
 
232
  # augmentation
233
- flip_p = 0.5 if flip_aug else 0.0
234
- if color_aug:
235
- # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
236
- self.aug = albu.Compose([
237
- albu.OneOf([
238
- albu.HueSaturationValue(8, 0, 0, p=.5),
239
- albu.RandomGamma((95, 105), p=.5),
240
- ], p=.33),
241
- albu.HorizontalFlip(p=flip_p)
242
- ], p=1.)
243
- elif flip_aug:
244
- self.aug = albu.Compose([
245
- albu.HorizontalFlip(p=flip_p)
246
- ], p=1.)
247
- else:
248
- self.aug = None
249
 
250
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
251
 
252
  self.image_data: Dict[str, ImageInfo] = {}
 
253
 
254
  self.replacements = {}
255
 
256
  def set_current_epoch(self, epoch):
257
  self.current_epoch = epoch
258
-
259
- def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
260
- # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
261
- self.dropout_rate = dropout_rate
262
- self.dropout_every_n_epochs = dropout_every_n_epochs
263
- self.tag_dropout_rate = tag_dropout_rate
264
 
265
  def set_tag_frequency(self, dir_name, captions):
266
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
  self.tag_frequency[dir_name] = frequency_for_dir
268
  for caption in captions:
269
  for tag in caption.split(","):
270
- if tag and not tag.isspace():
 
271
  tag = tag.lower()
272
  frequency = frequency_for_dir.get(tag, 0)
273
  frequency_for_dir[tag] = frequency + 1
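The options stripped from BaseDataset here (caption shuffling, keep_tokens, color/flip augmentation, caption dropout, ...) move to per-subset parameters, described by the dataclasses added in library/config_util.py. A hypothetical subset, for orientation (values and path are placeholders):

from library.config_util import DreamBoothSubsetParams

params = DreamBoothSubsetParams(
    image_dir="/path/to/10_mychar",  # placeholder path
    num_repeats=10,
    shuffle_caption=True,
    keep_tokens=1,
    caption_dropout_rate=0.05,
    class_tokens="mychar girl",
)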
@@ -278,42 +342,36 @@ class BaseDataset(torch.utils.data.Dataset):
278
  def add_replacement(self, str_from, str_to):
279
  self.replacements[str_from] = str_to
280
 
281
- def process_caption(self, caption):
282
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
283
- is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
284
- is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
285
 
286
  if is_drop_out:
287
  caption = ""
288
  else:
289
- if self.shuffle_caption or self.tag_dropout_rate > 0:
290
  def dropout_tags(tokens):
291
- if self.tag_dropout_rate <= 0:
292
  return tokens
293
  l = []
294
  for token in tokens:
295
- if random.random() >= self.tag_dropout_rate:
296
  l.append(token)
297
  return l
298
 
299
- tokens = [t.strip() for t in caption.strip().split(",")]
300
- if self.shuffle_keep_tokens is None:
301
- if self.shuffle_caption:
302
- random.shuffle(tokens)
303
-
304
- tokens = dropout_tags(tokens)
305
- else:
306
- if len(tokens) > self.shuffle_keep_tokens:
307
- keep_tokens = tokens[:self.shuffle_keep_tokens]
308
- tokens = tokens[self.shuffle_keep_tokens:]
309
 
310
- if self.shuffle_caption:
311
- random.shuffle(tokens)
312
 
313
- tokens = dropout_tags(tokens)
314
 
315
- tokens = keep_tokens + tokens
316
- caption = ", ".join(tokens)
317
 
318
  # textual inversion対応
319
  for str_from, str_to in self.replacements.items():
@@ -367,8 +425,9 @@ class BaseDataset(torch.utils.data.Dataset):
367
  input_ids = torch.stack(iids_list) # 3,77
368
  return input_ids
369
 
370
- def register_image(self, info: ImageInfo):
371
  self.image_data[info.image_key] = info
 
372
 
373
  def make_buckets(self):
374
  '''
@@ -467,7 +526,7 @@ class BaseDataset(torch.utils.data.Dataset):
467
  img = np.array(image, np.uint8)
468
  return img
469
 
470
- def trim_and_resize_if_required(self, image, reso, resized_size):
471
  image_height, image_width = image.shape[0:2]
472
 
473
  if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -477,22 +536,27 @@ class BaseDataset(torch.utils.data.Dataset):
477
  image_height, image_width = image.shape[0:2]
478
  if image_width > reso[0]:
479
  trim_size = image_width - reso[0]
480
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
  # print("w", trim_size, p)
482
  image = image[:, p:p + reso[0]]
483
  if image_height > reso[1]:
484
  trim_size = image_height - reso[1]
485
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
  # print("h", trim_size, p)
487
  image = image[p:p + reso[1]]
488
 
489
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
  return image
491
 
 
 
 
492
  def cache_latents(self, vae):
493
  # TODO ここを高速化したい
494
  print("caching latents.")
495
  for info in tqdm(self.image_data.values()):
 
 
496
  if info.latents_npz is not None:
497
  info.latents = self.load_latents_from_npz(info, False)
498
  info.latents = torch.FloatTensor(info.latents)
@@ -502,13 +566,13 @@ class BaseDataset(torch.utils.data.Dataset):
502
  continue
503
 
504
  image = self.load_image(info.absolute_path)
505
- image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
 
507
  img_tensor = self.image_transforms(image)
508
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
 
511
- if self.flip_aug:
512
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
  img_tensor = self.image_transforms(image)
514
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -518,11 +582,11 @@ class BaseDataset(torch.utils.data.Dataset):
518
  image = Image.open(image_path)
519
  return image.size
520
 
521
- def load_image_with_face_info(self, image_path: str):
522
  img = self.load_image(image_path)
523
 
524
  face_cx = face_cy = face_w = face_h = 0
525
- if self.face_crop_aug_range is not None:
526
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
  if len(tokens) >= 5:
528
  face_cx = int(tokens[-4])
@@ -533,7 +597,7 @@ class BaseDataset(torch.utils.data.Dataset):
533
  return img, face_cx, face_cy, face_w, face_h
534
 
535
  # いい感じに切り出す
536
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
  height, width = image.shape[0:2]
538
  if height == self.height and width == self.width:
539
  return image
@@ -541,8 +605,8 @@ class BaseDataset(torch.utils.data.Dataset):
541
  # 画像サイズはsizeより大きいのでリサイズする
542
  face_size = max(face_w, face_h)
543
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
544
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
545
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
546
  if min_scale >= max_scale: # range指定がmin==max
547
  scale = min_scale
548
  else:
@@ -560,13 +624,13 @@ class BaseDataset(torch.utils.data.Dataset):
560
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
562
 
563
- if self.random_crop:
564
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
565
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
566
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
567
  else:
568
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
569
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
  if face_size > self.size // 10 and face_size >= 40:
571
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
 
@@ -589,9 +653,6 @@ class BaseDataset(torch.utils.data.Dataset):
589
  return self._length
590
 
591
  def __getitem__(self, index):
592
- if index == 0:
593
- self.shuffle_buckets()
594
-
595
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -604,28 +665,29 @@ class BaseDataset(torch.utils.data.Dataset):
604
 
605
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
  image_info = self.image_data[image_key]
 
607
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
 
609
  # image/latentsを処理する
610
  if image_info.latents is not None:
611
- latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
  image = None
613
  elif image_info.latents_npz is not None:
614
- latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
  latents = torch.FloatTensor(latents)
616
  image = None
617
  else:
618
  # 画像を読み込み、必要ならcropする
619
- img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
  im_h, im_w = img.shape[0:2]
621
 
622
  if self.enable_bucket:
623
- img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
  else:
625
  if face_cx > 0: # 顔位置情報あり
626
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
  elif im_h > self.height or im_w > self.width:
628
- assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
  if im_h > self.height:
630
  p = random.randint(0, im_h - self.height)
631
  img = img[p:p + self.height]
@@ -637,8 +699,9 @@ class BaseDataset(torch.utils.data.Dataset):
637
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
 
639
  # augmentation
640
- if self.aug is not None:
641
- img = self.aug(image=img)['image']
 
642
 
643
  latents = None
644
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -646,7 +709,7 @@ class BaseDataset(torch.utils.data.Dataset):
646
  images.append(image)
647
  latents_list.append(latents)
648
 
649
- caption = self.process_caption(image_info.caption)
650
  captions.append(caption)
651
  if not self.token_padding_disabled: # this option might be omitted in future
652
  input_ids_list.append(self.get_input_ids(caption))
@@ -677,9 +740,8 @@ class BaseDataset(torch.utils.data.Dataset):
677
 
678
 
679
  class DreamBoothDataset(BaseDataset):
680
- def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
 
684
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
 
@@ -702,7 +764,7 @@ class DreamBoothDataset(BaseDataset):
702
  self.bucket_reso_steps = None # この情報は使われない
703
  self.bucket_no_upscale = False
704
 
705
- def read_caption(img_path):
706
  # captionの候補ファイル名を作る
707
  base_name = os.path.splitext(img_path)[0]
708
  base_name_face_det = base_name
@@ -725,153 +787,171 @@ class DreamBoothDataset(BaseDataset):
725
  break
726
  return caption
727
 
728
- def load_dreambooth_dir(dir):
729
- if not os.path.isdir(dir):
730
- # print(f"ignore file: {dir}")
731
- return 0, [], []
732
-
733
- tokens = os.path.basename(dir).split('_')
734
- try:
735
- n_repeats = int(tokens[0])
736
- except ValueError as e:
737
- print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
- return 0, [], []
739
 
740
- caption_by_folder = '_'.join(tokens[1:])
741
- img_paths = glob_images(dir, "*")
742
- print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
 
744
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
745
  captions = []
746
  for img_path in img_paths:
747
- cap_for_img = read_caption(img_path)
748
- captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
749
-
750
- self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
751
 
752
- return n_repeats, img_paths, captions
753
 
754
- print("prepare train images.")
755
- train_dirs = os.listdir(train_data_dir)
756
  num_train_images = 0
757
- for dir in train_dirs:
758
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
- num_train_images += n_repeats * len(img_paths)
760
 
761
  for img_path, caption in zip(img_paths, captions):
762
- info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
- self.register_image(info)
764
 
765
- self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
766
 
767
  print(f"{num_train_images} train images with repeating.")
768
  self.num_train_images = num_train_images
769
 
770
- # reg imageは数を数えて学習画像と同じ枚数にする
771
- num_reg_images = 0
772
- if reg_data_dir:
773
- print("prepare reg images.")
774
- reg_infos: List[ImageInfo] = []
775
 
776
- reg_dirs = os.listdir(reg_data_dir)
777
- for dir in reg_dirs:
778
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
- num_reg_images += n_repeats * len(img_paths)
780
 
781
- for img_path, caption in zip(img_paths, captions):
782
- info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
- reg_infos.append(info)
784
 
785
- self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
 
787
- print(f"{num_reg_images} reg images.")
788
- if num_train_images < num_reg_images:
789
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
 
791
- if num_reg_images == 0:
792
- print("no regularization images / 正則化画像が見つかりませんでした")
793
  else:
794
- # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
795
- n = 0
796
- first_loop = True
797
- while n < num_train_images:
798
- for info in reg_infos:
799
- if first_loop:
800
- self.register_image(info)
801
- n += info.num_repeats
802
- else:
803
- info.num_repeats += 1
804
- n += 1
805
- if n >= num_train_images:
806
- break
807
- first_loop = False
808
 
809
- self.num_reg_images = num_reg_images
 
 
810
 
811
 
812
- class FineTuningDataset(BaseDataset):
813
- def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
-
817
- # メタデータを読み込む
818
- if os.path.exists(json_file_name):
819
- print(f"loading existing metadata: {json_file_name}")
820
- with open(json_file_name, "rt", encoding='utf-8') as f:
821
- metadata = json.load(f)
822
- else:
823
- raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
 
825
- self.metadata = metadata
826
- self.train_data_dir = train_data_dir
827
- self.batch_size = batch_size
828
 
829
- tags_list = []
830
- for image_key, img_md in metadata.items():
831
- # path情報を作る
832
- if os.path.exists(image_key):
833
- abs_path = image_key
834
- else:
835
- # わりといい加減だがいい方法が思いつかん
836
- abs_path = glob_images(train_data_dir, image_key)
837
- assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
- abs_path = abs_path[0]
839
-
840
- caption = img_md.get('caption')
841
- tags = img_md.get('tags')
842
- if caption is None:
843
- caption = tags
844
- elif tags is not None and len(tags) > 0:
845
- caption = caption + ', ' + tags
846
- tags_list.append(tags)
847
- assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
-
849
- image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
- image_info.image_size = img_md.get('train_resolution')
851
-
852
- if not self.color_aug and not self.random_crop:
853
- # if npz exists, use them
854
- image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
-
856
- self.register_image(image_info)
857
- self.num_train_images = len(metadata) * dataset_repeats
858
- self.num_reg_images = 0
859
 
860
- # TODO do not record tag freq when no tag
861
- self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
- self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
863
 
864
  # check existence of all npz files
865
- use_npz_latents = not (self.color_aug or self.random_crop)
866
  if use_npz_latents:
 
867
  npz_any = False
868
  npz_all = True
 
869
  for image_info in self.image_data.values():
 
 
870
  has_npz = image_info.latents_npz is not None
871
  npz_any = npz_any or has_npz
872
 
873
- if self.flip_aug:
874
  has_npz = has_npz and image_info.latents_npz_flipped is not None
 
875
  npz_all = npz_all and has_npz
876
 
877
  if npz_any and not npz_all:
@@ -883,7 +963,7 @@ class FineTuningDataset(BaseDataset):
883
  elif not npz_all:
884
  use_npz_latents = False
885
  print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
886
- if self.flip_aug:
887
  print("maybe no flipped files / ��転されたnpzファイルがないのかもしれません")
888
  # else:
889
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -929,7 +1009,7 @@ class FineTuningDataset(BaseDataset):
929
  for image_info in self.image_data.values():
930
  image_info.latents_npz = image_info.latents_npz_flipped = None
931
 
932
- def image_key_to_npz_file(self, image_key):
933
  base_name = os.path.splitext(image_key)[0]
934
  npz_file_norm = base_name + '.npz'
935
 
@@ -941,8 +1021,8 @@ class FineTuningDataset(BaseDataset):
941
  return npz_file_norm, npz_file_flip
942
 
943
  # image_key is relative path
944
- npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
- npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
 
947
  if not os.path.exists(npz_file_norm):
948
  npz_file_norm = None
@@ -953,13 +1033,60 @@ class FineTuningDataset(BaseDataset):
953
  return npz_file_norm, npz_file_flip
954
 
955
 
956
  def debug_dataset(train_dataset, show_input_ids=False):
957
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
  print("Escape for exit. / Escキーで中断、終了します")
959
 
960
  train_dataset.set_current_epoch(1)
961
  k = 0
962
- for i, example in enumerate(train_dataset):
963
  if example['latents'] is not None:
964
  print(f"sample has latents from npz file: {example['latents'].size()}")
965
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
@@ -1364,6 +1491,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1364
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
 
1368
 
1369
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1387,10 +1543,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1387
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
- parser.add_argument("--use_8bit_adam", action="store_true",
1391
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
- parser.add_argument("--use_lion_optimizer", action="store_true",
1393
- help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
  parser.add_argument("--mem_eff_attn", action="store_true",
1395
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
  parser.add_argument("--xformers", action="store_true",
@@ -1398,7 +1550,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1398
  parser.add_argument("--vae", type=str, default=None,
1399
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
 
1401
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
  parser.add_argument("--max_train_epochs", type=int, default=None,
1404
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1419,15 +1570,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1419
  parser.add_argument("--logging_dir", type=str, default=None,
1420
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
- parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
  parser.add_argument("--noise_offset", type=float, default=None,
1427
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
  parser.add_argument("--lowram", action="store_true",
1429
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
 
1431
  if support_dreambooth:
1432
  # DreamBooth training
1433
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1449,8 +1608,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1449
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
  parser.add_argument("--caption_extention", type=str, default=None,
1451
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
- parser.add_argument("--keep_tokens", type=int, default=None,
1453
- help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1475,11 +1634,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1475
  if support_caption_dropout:
1476
  # Textual Inversion はcaptionのdropoutをsupportしない
1477
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1478
- parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
- parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
- parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
 
1485
  if support_dreambooth:
@@ -1504,16 +1663,249 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1504
  # region utils
1505
 
1506
 
1507
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
  # backward compatibility
1509
  if args.caption_extention is not None:
1510
  args.caption_extension = args.caption_extention
1511
  args.caption_extention = None
1512
 
1513
- if args.cache_latents:
1514
- assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
- assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
-
1517
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
  if args.resolution is not None:
1519
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
@@ -1536,12 +1928,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1536
 
1537
  def load_tokenizer(args: argparse.Namespace):
1538
  print("prepare tokenizer")
1539
- if args.v2:
1540
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
- else:
1542
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
- if args.max_token_length is not None:
1544
  print(f"update token length: {args.max_token_length}")
1545
  return tokenizer
1546
 
1547
 
@@ -1592,13 +2000,19 @@ def prepare_dtype(args: argparse.Namespace):
1592
 
1593
 
1594
  def load_target_model(args: argparse.Namespace, weight_dtype):
1595
- load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
 
 
1596
  if load_stable_diffusion_format:
1597
  print("load StableDiffusion checkpoint")
1598
- text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
  else:
1600
  print("load Diffusers pretrained models")
1601
- pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1602
  text_encoder = pipe.text_encoder
1603
  vae = pipe.vae
1604
  unet = pipe.unet
@@ -1767,6 +2181,185 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1767
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
 
1770
  # endregion
1771
 
1772
  # region 前処理用
 
1
  # common functions for training
2
 
3
  import argparse
4
+ import importlib
5
  import json
6
+ import re
7
  import shutil
8
  import time
9
+ from typing import (
10
+ Dict,
11
+ List,
12
+ NamedTuple,
13
+ Optional,
14
+ Sequence,
15
+ Tuple,
16
+ Union,
17
+ )
18
  from accelerate import Accelerator
 
19
  import glob
20
  import math
21
  import os
 
26
 
27
  from tqdm import tqdm
28
  import torch
29
+ from torch.optim import Optimizer
30
  from torchvision import transforms
31
  from transformers import CLIPTokenizer
32
+ import transformers
33
  import diffusers
34
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+ from diffusers import (StableDiffusionPipeline, DDPMScheduler,
36
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
37
+ LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
38
+ KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
39
  import albumentations as albu
40
  import numpy as np
41
  from PIL import Image
 
210
  batch_index: int
211
 
212
 
213
+ class AugHelper:
214
+ def __init__(self):
215
+ # prepare all possible augmentators
216
+ color_aug_method = albu.OneOf([
217
+ albu.HueSaturationValue(8, 0, 0, p=.5),
218
+ albu.RandomGamma((95, 105), p=.5),
219
+ ], p=.33)
220
+ flip_aug_method = albu.HorizontalFlip(p=0.5)
221
+
222
+ # key: (use_color_aug, use_flip_aug)
223
+ self.augmentors = {
224
+ (True, True): albu.Compose([
225
+ color_aug_method,
226
+ flip_aug_method,
227
+ ], p=1.),
228
+ (True, False): albu.Compose([
229
+ color_aug_method,
230
+ ], p=1.),
231
+ (False, True): albu.Compose([
232
+ flip_aug_method,
233
+ ], p=1.),
234
+ (False, False): None
235
+ }
236
+
237
+ def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
238
+ return self.augmentors[(use_color_aug, use_flip_aug)]
239
+
240
+
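For reference, a minimal usage sketch of the AugHelper lookup above (not part of the patch); it assumes the class above is importable, albumentations is installed, and the image is a synthetic placeholder:

import numpy as np

aug_helper = AugHelper()

# (color_aug, flip_aug) selects one of the four pre-built pipelines; (False, False) maps to None
aug = aug_helper.get_augmentor(use_color_aug=True, use_flip_aug=False)

img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HWC uint8 image
if aug is not None:
    img = aug(image=img)['image']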
241
+ class BaseSubset:
242
+ def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
243
+ self.image_dir = image_dir
244
+ self.num_repeats = num_repeats
245
+ self.shuffle_caption = shuffle_caption
246
+ self.keep_tokens = keep_tokens
247
+ self.color_aug = color_aug
248
+ self.flip_aug = flip_aug
249
+ self.face_crop_aug_range = face_crop_aug_range
250
+ self.random_crop = random_crop
251
+ self.caption_dropout_rate = caption_dropout_rate
252
+ self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
253
+ self.caption_tag_dropout_rate = caption_tag_dropout_rate
254
+
255
+ self.img_count = 0
256
+
257
+
258
+ class DreamBoothSubset(BaseSubset):
259
+ def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
260
+ assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
261
+
262
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
263
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
264
+
265
+ self.is_reg = is_reg
266
+ self.class_tokens = class_tokens
267
+ self.caption_extension = caption_extension
268
+
269
+ def __eq__(self, other) -> bool:
270
+ if not isinstance(other, DreamBoothSubset):
271
+ return NotImplemented
272
+ return self.image_dir == other.image_dir
273
+
274
+ class FineTuningSubset(BaseSubset):
275
+ def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
276
+ assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
277
+
278
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
279
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
280
+
281
+ self.metadata_file = metadata_file
282
+
283
+ def __eq__(self, other) -> bool:
284
+ if not isinstance(other, FineTuningSubset):
285
+ return NotImplemented
286
+ return self.metadata_file == other.metadata_file
287
+
288
  class BaseDataset(torch.utils.data.Dataset):
289
+ def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
290
  super().__init__()
291
+ self.tokenizer = tokenizer
292
  self.max_token_length = max_token_length
 
 
293
  # width/height is used when enable_bucket==False
294
  self.width, self.height = (None, None) if resolution is None else resolution
295
  self.debug_dataset = debug_dataset
296
+
297
+ self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
298
+
299
  self.token_padding_disabled = False
 
 
300
  self.tag_frequency = {}
301
 
302
  self.enable_bucket = False
 
310
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
311
 
312
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
313
 
314
  # augmentation
315
+ self.aug_helper = AugHelper()
316
 
317
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
318
 
319
  self.image_data: Dict[str, ImageInfo] = {}
320
+ self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
321
 
322
  self.replacements = {}
323
 
324
  def set_current_epoch(self, epoch):
325
  self.current_epoch = epoch
326
+ self.shuffle_buckets()
327
 
328
  def set_tag_frequency(self, dir_name, captions):
329
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
330
  self.tag_frequency[dir_name] = frequency_for_dir
331
  for caption in captions:
332
  for tag in caption.split(","):
333
+ tag = tag.strip()
334
+ if tag:
335
  tag = tag.lower()
336
  frequency = frequency_for_dir.get(tag, 0)
337
  frequency_for_dir[tag] = frequency + 1
 
342
  def add_replacement(self, str_from, str_to):
343
  self.replacements[str_from] = str_to
344
 
345
+ def process_caption(self, subset: BaseSubset, caption):
346
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
347
+ is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
348
+ is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
349
 
350
  if is_drop_out:
351
  caption = ""
352
  else:
353
+ if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
354
  def dropout_tags(tokens):
355
+ if subset.caption_tag_dropout_rate <= 0:
356
  return tokens
357
  l = []
358
  for token in tokens:
359
+ if random.random() >= subset.caption_tag_dropout_rate:
360
  l.append(token)
361
  return l
362
 
363
+ fixed_tokens = []
364
+ flex_tokens = [t.strip() for t in caption.strip().split(",")]
365
+ if subset.keep_tokens > 0:
366
+ fixed_tokens = flex_tokens[:subset.keep_tokens]
367
+ flex_tokens = flex_tokens[subset.keep_tokens:]
368
 
369
+ if subset.shuffle_caption:
370
+ random.shuffle(flex_tokens)
371
 
372
+ flex_tokens = dropout_tags(flex_tokens)
373
 
374
+ caption = ", ".join(fixed_tokens + flex_tokens)
 
375
 
376
  # textual inversion対応
377
  for str_from, str_to in self.replacements.items():
 
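As a standalone illustration of the keep_tokens / shuffle / tag-dropout behaviour that process_caption implements in the hunk above (a sketch only; the 0.4 dropout rate and the sample caption are arbitrary example values):

import random

def shuffle_and_drop_tags(caption: str, keep_tokens: int = 1, tag_dropout_rate: float = 0.4) -> str:
    # split on commas; the first keep_tokens entries stay fixed, the rest are shuffled and randomly dropped
    tokens = [t.strip() for t in caption.strip().split(",")]
    fixed, flex = tokens[:keep_tokens], tokens[keep_tokens:]
    random.shuffle(flex)
    flex = [t for t in flex if random.random() >= tag_dropout_rate]
    return ", ".join(fixed + flex)

print(shuffle_and_drop_tags("1girl, blue hair, smile, outdoors"))
# e.g. "1girl, outdoors, smile" (the leading tag is always kept)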
425
  input_ids = torch.stack(iids_list) # 3,77
426
  return input_ids
427
 
428
+ def register_image(self, info: ImageInfo, subset: BaseSubset):
429
  self.image_data[info.image_key] = info
430
+ self.image_to_subset[info.image_key] = subset
431
 
432
  def make_buckets(self):
433
  '''
 
526
  img = np.array(image, np.uint8)
527
  return img
528
 
529
+ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
530
  image_height, image_width = image.shape[0:2]
531
 
532
  if image_width != resized_size[0] or image_height != resized_size[1]:
 
536
  image_height, image_width = image.shape[0:2]
537
  if image_width > reso[0]:
538
  trim_size = image_width - reso[0]
539
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
540
  # print("w", trim_size, p)
541
  image = image[:, p:p + reso[0]]
542
  if image_height > reso[1]:
543
  trim_size = image_height - reso[1]
544
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
545
  # print("h", trim_size, p)
546
  image = image[p:p + reso[1]]
547
 
548
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
549
  return image
550
 
551
+ def is_latent_cacheable(self):
552
+ return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
553
+
554
  def cache_latents(self, vae):
555
  # TODO ここを高速化したい
556
  print("caching latents.")
557
  for info in tqdm(self.image_data.values()):
558
+ subset = self.image_to_subset[info.image_key]
559
+
560
  if info.latents_npz is not None:
561
  info.latents = self.load_latents_from_npz(info, False)
562
  info.latents = torch.FloatTensor(info.latents)
 
566
  continue
567
 
568
  image = self.load_image(info.absolute_path)
569
+ image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
570
 
571
  img_tensor = self.image_transforms(image)
572
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
573
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
574
 
575
+ if subset.flip_aug:
576
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
577
  img_tensor = self.image_transforms(image)
578
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
 
582
  image = Image.open(image_path)
583
  return image.size
584
 
585
+ def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
586
  img = self.load_image(image_path)
587
 
588
  face_cx = face_cy = face_w = face_h = 0
589
+ if subset.face_crop_aug_range is not None:
590
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
591
  if len(tokens) >= 5:
592
  face_cx = int(tokens[-4])
 
597
  return img, face_cx, face_cy, face_w, face_h
598
 
599
  # いい感じに切り出す
600
+ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
601
  height, width = image.shape[0:2]
602
  if height == self.height and width == self.width:
603
  return image
 
605
  # 画像サイズはsizeより大きいのでリサイズする
606
  face_size = max(face_w, face_h)
607
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
608
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
609
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
610
  if min_scale >= max_scale: # range指定がmin==max
611
  scale = min_scale
612
  else:
 
624
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
625
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
626
 
627
+ if subset.random_crop:
628
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
629
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
630
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
631
  else:
632
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
633
+ if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
634
  if face_size > self.size // 10 and face_size >= 40:
635
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
636
 
 
653
  return self._length
654
 
655
  def __getitem__(self, index):
656
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
657
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
658
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
 
665
 
666
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
667
  image_info = self.image_data[image_key]
668
+ subset = self.image_to_subset[image_key]
669
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
670
 
671
  # image/latentsを処理する
672
  if image_info.latents is not None:
673
+ latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
674
  image = None
675
  elif image_info.latents_npz is not None:
676
+ latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
677
  latents = torch.FloatTensor(latents)
678
  image = None
679
  else:
680
  # 画像を読み込み、必要ならcropする
681
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
682
  im_h, im_w = img.shape[0:2]
683
 
684
  if self.enable_bucket:
685
+ img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
686
  else:
687
  if face_cx > 0: # 顔位置情報あり
688
+ img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
689
  elif im_h > self.height or im_w > self.width:
690
+ assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
691
  if im_h > self.height:
692
  p = random.randint(0, im_h - self.height)
693
  img = img[p:p + self.height]
 
699
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
700
 
701
  # augmentation
702
+ aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
703
+ if aug is not None:
704
+ img = aug(image=img)['image']
705
 
706
  latents = None
707
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
 
709
  images.append(image)
710
  latents_list.append(latents)
711
 
712
+ caption = self.process_caption(subset, image_info.caption)
713
  captions.append(caption)
714
  if not self.token_padding_disabled: # this option might be omitted in future
715
  input_ids_list.append(self.get_input_ids(caption))
 
740
 
741
 
742
  class DreamBoothDataset(BaseDataset):
743
+ def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
744
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
745
 
746
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
747
 
 
764
  self.bucket_reso_steps = None # この情報は使われない
765
  self.bucket_no_upscale = False
766
 
767
+ def read_caption(img_path, caption_extension):
768
  # captionの候補ファイル名を作る
769
  base_name = os.path.splitext(img_path)[0]
770
  base_name_face_det = base_name
 
787
  break
788
  return caption
789
 
790
+ def load_dreambooth_dir(subset: DreamBoothSubset):
791
+ if not os.path.isdir(subset.image_dir):
792
+ print(f"not directory: {subset.image_dir}")
793
+ return [], []
794
 
795
+ img_paths = glob_images(subset.image_dir, "*")
796
+ print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
797
 
798
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
799
  captions = []
800
  for img_path in img_paths:
801
+ cap_for_img = read_caption(img_path, subset.caption_extension)
802
+ if cap_for_img is None and subset.class_tokens is None:
803
+ print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
804
+ captions.append("")
805
+ else:
806
+ captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
807
+
808
+ self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
809
 
810
+ return img_paths, captions
811
 
812
+ print("prepare images.")
 
813
  num_train_images = 0
814
+ num_reg_images = 0
815
+ reg_infos: List[ImageInfo] = []
816
+ for subset in subsets:
817
+ if subset.num_repeats < 1:
818
+ print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
819
+ continue
820
+
821
+ if subset in self.subsets:
822
+ print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
823
+ continue
824
+
825
+ img_paths, captions = load_dreambooth_dir(subset)
826
+ if len(img_paths) < 1:
827
+ print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
828
+ continue
829
+
830
+ if subset.is_reg:
831
+ num_reg_images += subset.num_repeats * len(img_paths)
832
+ else:
833
+ num_train_images += subset.num_repeats * len(img_paths)
834
 
835
  for img_path, caption in zip(img_paths, captions):
836
+ info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
837
+ if subset.is_reg:
838
+ reg_infos.append(info)
839
+ else:
840
+ self.register_image(info, subset)
841
 
842
+ subset.img_count = len(img_paths)
843
+ self.subsets.append(subset)
844
 
845
  print(f"{num_train_images} train images with repeating.")
846
  self.num_train_images = num_train_images
847
 
848
+ print(f"{num_reg_images} reg images.")
849
+ if num_train_images < num_reg_images:
850
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
 
 
851
 
852
+ if num_reg_images == 0:
853
+ print("no regularization images / 正則化画像が見つかりませんでした")
854
+ else:
855
+ # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
856
+ n = 0
857
+ first_loop = True
858
+ while n < num_train_images:
859
+ for info in reg_infos:
860
+ if first_loop:
861
+ self.register_image(info, subset)
862
+ n += info.num_repeats
863
+ else:
864
+ info.num_repeats += 1
865
+ n += 1
866
+ if n >= num_train_images:
867
+ break
868
+ first_loop = False
869
 
870
+ self.num_reg_images = num_reg_images
 
 
871
 
 
872
 
873
+ class FineTuningDataset(BaseDataset):
874
+ def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
875
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
876
+
877
+ self.batch_size = batch_size
878
+
879
+ self.num_train_images = 0
880
+ self.num_reg_images = 0
881
 
882
+ for subset in subsets:
883
+ if subset.num_repeats < 1:
884
+ print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
885
+ continue
886
+
887
+ if subset in self.subsets:
888
+ print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
889
+ continue
890
+
891
+ # メタデータを読み込む
892
+ if os.path.exists(subset.metadata_file):
893
+ print(f"loading existing metadata: {subset.metadata_file}")
894
+ with open(subset.metadata_file, "rt", encoding='utf-8') as f:
895
+ metadata = json.load(f)
896
  else:
897
+ raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
898
 
899
+ if len(metadata) < 1:
900
+ print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
901
+ continue
902
 
903
+ tags_list = []
904
+ for image_key, img_md in metadata.items():
905
+ # path情報を作る
906
+ if os.path.exists(image_key):
907
+ abs_path = image_key
908
+ else:
909
+ # わりといい加減だがいい方法が思いつかん
910
+ abs_path = glob_images(subset.image_dir, image_key)
911
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
912
+ abs_path = abs_path[0]
913
 
914
+ caption = img_md.get('caption')
915
+ tags = img_md.get('tags')
916
+ if caption is None:
917
+ caption = tags
918
+ elif tags is not None and len(tags) > 0:
919
+ caption = caption + ', ' + tags
920
+ tags_list.append(tags)
921
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
922
 
923
+ image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
924
+ image_info.image_size = img_md.get('train_resolution')
 
925
 
926
+ if not subset.color_aug and not subset.random_crop:
927
+ # if npz exists, use them
928
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
929
+
930
+ self.register_image(image_info, subset)
931
 
932
+ self.num_train_images += len(metadata) * subset.num_repeats
933
+
934
+ # TODO do not record tag freq when no tag
935
+ self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
936
+ subset.img_count = len(metadata)
937
+ self.subsets.append(subset)
938
 
939
  # check existence of all npz files
940
+ use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
941
  if use_npz_latents:
942
+ flip_aug_in_subset = False
943
  npz_any = False
944
  npz_all = True
945
+
946
  for image_info in self.image_data.values():
947
+ subset = self.image_to_subset[image_info.image_key]
948
+
949
  has_npz = image_info.latents_npz is not None
950
  npz_any = npz_any or has_npz
951
 
952
+ if subset.flip_aug:
953
  has_npz = has_npz and image_info.latents_npz_flipped is not None
954
+ flip_aug_in_subset = True
955
  npz_all = npz_all and has_npz
956
 
957
  if npz_any and not npz_all:
 
963
  elif not npz_all:
964
  use_npz_latents = False
965
  print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
966
+ if flip_aug_in_subset:
967
  print("maybe no flipped files / ��転されたnpzファイルがないのかもしれません")
968
  # else:
969
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
 
1009
  for image_info in self.image_data.values():
1010
  image_info.latents_npz = image_info.latents_npz_flipped = None
1011
 
1012
+ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
1013
  base_name = os.path.splitext(image_key)[0]
1014
  npz_file_norm = base_name + '.npz'
1015
 
 
1021
  return npz_file_norm, npz_file_flip
1022
 
1023
  # image_key is relative path
1024
+ npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
1025
+ npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
1026
 
1027
  if not os.path.exists(npz_file_norm):
1028
  npz_file_norm = None
 
1033
  return npz_file_norm, npz_file_flip
1034
 
1035
 
1036
+ # behave as Dataset mock
1037
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1038
+ def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
1039
+ self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
1040
+
1041
+ super().__init__(datasets)
1042
+
1043
+ self.image_data = {}
1044
+ self.num_train_images = 0
1045
+ self.num_reg_images = 0
1046
+
1047
+ # simply concat together
1048
+ # TODO: handling image_data key duplication among dataset
1049
+ # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
1050
+ for dataset in datasets:
1051
+ self.image_data.update(dataset.image_data)
1052
+ self.num_train_images += dataset.num_train_images
1053
+ self.num_reg_images += dataset.num_reg_images
1054
+
1055
+ def add_replacement(self, str_from, str_to):
1056
+ for dataset in self.datasets:
1057
+ dataset.add_replacement(str_from, str_to)
1058
+
1059
+ # def make_buckets(self):
1060
+ # for dataset in self.datasets:
1061
+ # dataset.make_buckets()
1062
+
1063
+ def cache_latents(self, vae):
1064
+ for i, dataset in enumerate(self.datasets):
1065
+ print(f"[Dataset {i}]")
1066
+ dataset.cache_latents(vae)
1067
+
1068
+ def is_latent_cacheable(self) -> bool:
1069
+ return all([dataset.is_latent_cacheable() for dataset in self.datasets])
1070
+
1071
+ def set_current_epoch(self, epoch):
1072
+ for dataset in self.datasets:
1073
+ dataset.set_current_epoch(epoch)
1074
+
1075
+ def disable_token_padding(self):
1076
+ for dataset in self.datasets:
1077
+ dataset.disable_token_padding()
1078
+
1079
+
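A minimal sketch of how a training script might wrap per-config datasets with the new DatasetGroup; the datasets and the VAE are assumed to be constructed elsewhere with the constructors shown above, and the helper name is illustrative:

from typing import Sequence, Union

def build_dataset_group(datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]], vae=None) -> DatasetGroup:
    group = DatasetGroup(datasets)

    # latents can only be cached when no subset uses color_aug or random_crop
    if vae is not None and group.is_latent_cacheable():
        group.cache_latents(vae)

    print(f"{group.num_train_images} train images / {group.num_reg_images} reg images")
    return group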
1080
  def debug_dataset(train_dataset, show_input_ids=False):
1081
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
1082
  print("Escape for exit. / Escキーで中断、終了します")
1083
 
1084
  train_dataset.set_current_epoch(1)
1085
  k = 0
1086
+ indices = list(range(len(train_dataset)))
1087
+ random.shuffle(indices)
1088
+ for i, idx in enumerate(indices):
1089
+ example = train_dataset[idx]
1090
  if example['latents'] is not None:
1091
  print(f"sample has latents from npz file: {example['latents'].size()}")
1092
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
 
1491
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1492
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1493
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1494
+ parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
1495
+ help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
1496
+
1497
+
1498
+ def add_optimizer_arguments(parser: argparse.ArgumentParser):
1499
+ parser.add_argument("--optimizer_type", type=str, default="",
1500
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
1501
+
1502
+ # backward compatibility
1503
+ parser.add_argument("--use_8bit_adam", action="store_true",
1504
+ help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1505
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1506
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1507
+
1508
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1509
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
1510
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
1511
+
1512
+ parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
1513
+ help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
1514
+
1515
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1516
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
1517
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1518
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1519
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
1520
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
1521
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
1522
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
1523
 
1524
 
1525
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
 
1543
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1544
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1545
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1546
  parser.add_argument("--mem_eff_attn", action="store_true",
1547
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1548
  parser.add_argument("--xformers", action="store_true",
 
1550
  parser.add_argument("--vae", type=str, default=None,
1551
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1552
 
 
1553
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1554
  parser.add_argument("--max_train_epochs", type=int, default=None,
1555
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
 
1570
  parser.add_argument("--logging_dir", type=str, default=None,
1571
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1572
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1573
  parser.add_argument("--noise_offset", type=float, default=None,
1574
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1575
  parser.add_argument("--lowram", action="store_true",
1576
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1577
 
1578
+ parser.add_argument("--sample_every_n_steps", type=int, default=None,
1579
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
1580
+ parser.add_argument("--sample_every_n_epochs", type=int, default=None,
1581
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
1582
+ parser.add_argument("--sample_prompts", type=str, default=None,
1583
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
1584
+ parser.add_argument('--sample_sampler', type=str, default='ddim',
1585
+ choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
1586
+ 'dpmsolver++', 'dpmsingle',
1587
+ 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
1588
+ help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
1589
+
1590
  if support_dreambooth:
1591
  # DreamBooth training
1592
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
 
1608
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1609
  parser.add_argument("--caption_extention", type=str, default=None,
1610
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1611
+ parser.add_argument("--keep_tokens", type=int, default=0,
1612
+ help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
1613
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1614
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1615
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
 
1634
  if support_caption_dropout:
1635
  # Textual Inversion はcaptionのdropoutをsupportしない
1636
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1637
+ parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
1638
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1639
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
1640
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1641
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
1642
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1643
 
1644
  if support_dreambooth:
 
1663
  # region utils
1664
 
1665
 
1666
+ def get_optimizer(args, trainable_params):
1667
+ # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
1668
+
1669
+ optimizer_type = args.optimizer_type
1670
+ if args.use_8bit_adam:
1671
+ assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
1672
+ assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
1673
+ optimizer_type = "AdamW8bit"
1674
+
1675
+ elif args.use_lion_optimizer:
1676
+ assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
1677
+ optimizer_type = "Lion"
1678
+
1679
+ if optimizer_type is None or optimizer_type == "":
1680
+ optimizer_type = "AdamW"
1681
+ optimizer_type = optimizer_type.lower()
1682
+
1683
+ # 引数を分解する:boolとfloat、tupleのみ対応
1684
+ optimizer_kwargs = {}
1685
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
1686
+ for arg in args.optimizer_args:
1687
+ key, value = arg.split('=')
1688
+
1689
+ value = value.split(",")
1690
+ for i in range(len(value)):
1691
+ if value[i].lower() == "true" or value[i].lower() == "false":
1692
+ value[i] = (value[i].lower() == "true")
1693
+ else:
1694
+ value[i] = float(value[i])
1695
+ if len(value) == 1:
1696
+ value = value[0]
1697
+ else:
1698
+ value = tuple(value)
1699
+
1700
+ optimizer_kwargs[key] = value
1701
+ # print("optkwargs:", optimizer_kwargs)
1702
+
1703
+ lr = args.learning_rate
1704
+
1705
+ if optimizer_type == "AdamW8bit".lower():
1706
+ try:
1707
+ import bitsandbytes as bnb
1708
+ except ImportError:
1709
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1710
+ print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
1711
+ optimizer_class = bnb.optim.AdamW8bit
1712
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1713
+
1714
+ elif optimizer_type == "SGDNesterov8bit".lower():
1715
+ try:
1716
+ import bitsandbytes as bnb
1717
+ except ImportError:
1718
+ raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
1719
+ print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
1720
+ if "momentum" not in optimizer_kwargs:
1721
+ print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1722
+ optimizer_kwargs["momentum"] = 0.9
1723
+
1724
+ optimizer_class = bnb.optim.SGD8bit
1725
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1726
+
1727
+ elif optimizer_type == "Lion".lower():
1728
+ try:
1729
+ import lion_pytorch
1730
+ except ImportError:
1731
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
1732
+ print(f"use Lion optimizer | {optimizer_kwargs}")
1733
+ optimizer_class = lion_pytorch.Lion
1734
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1735
+
1736
+ elif optimizer_type == "SGDNesterov".lower():
1737
+ print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
1738
+ if "momentum" not in optimizer_kwargs:
1739
+ print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1740
+ optimizer_kwargs["momentum"] = 0.9
1741
+
1742
+ optimizer_class = torch.optim.SGD
1743
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1744
+
1745
+ elif optimizer_type == "DAdaptation".lower():
1746
+ try:
1747
+ import dadaptation
1748
+ except ImportError:
1749
+ raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
1750
+ print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
1751
+
1752
+ min_lr = lr
1753
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1754
+ for group in trainable_params:
1755
+ min_lr = min(min_lr, group.get("lr", lr))
1756
+
1757
+ if min_lr <= 0.1:
1758
+ print(
1759
+ f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
1760
+ print('recommend option: lr=1.0 / 推奨は1.0です')
1761
+
1762
+ optimizer_class = dadaptation.DAdaptAdam
1763
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1764
+
1765
+ elif optimizer_type == "Adafactor".lower():
1766
+ # 引数を確認して適宜補正する
1767
+ if "relative_step" not in optimizer_kwargs:
1768
+ optimizer_kwargs["relative_step"] = True # default
1769
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
1770
+ print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
1771
+ optimizer_kwargs["relative_step"] = True
1772
+ print(f"use Adafactor optimizer | {optimizer_kwargs}")
1773
+
1774
+ if optimizer_kwargs["relative_step"]:
1775
+ print(f"relative_step is true / relative_stepがtrueです")
1776
+ if lr != 0.0:
1777
+ print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
1778
+ args.learning_rate = None
1779
+
1780
+ # trainable_paramsがgroupだった時の処理:lrを削除する
1781
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1782
+ has_group_lr = False
1783
+ for group in trainable_params:
1784
+ p = group.pop("lr", None)
1785
+ has_group_lr = has_group_lr or (p is not None)
1786
+
1787
+ if has_group_lr:
1788
+ # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
1789
+ print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
1790
+ args.unet_lr = None
1791
+ args.text_encoder_lr = None
1792
+
1793
+ if args.lr_scheduler != "adafactor":
1794
+ print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
1795
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
1796
+
1797
+ lr = None
1798
+ else:
1799
+ if args.max_grad_norm != 0.0:
1800
+ print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
1801
+ if args.lr_scheduler != "constant_with_warmup":
1802
+ print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
1803
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
1804
+ print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
1805
+
1806
+ optimizer_class = transformers.optimization.Adafactor
1807
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1808
+
1809
+ elif optimizer_type == "AdamW".lower():
1810
+ print(f"use AdamW optimizer | {optimizer_kwargs}")
1811
+ optimizer_class = torch.optim.AdamW
1812
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1813
+
1814
+ else:
1815
+ # 任意のoptimizerを使う
1816
+ optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
1817
+ print(f"use {optimizer_type} | {optimizer_kwargs}")
1818
+ if "." not in optimizer_type:
1819
+ optimizer_module = torch.optim
1820
+ else:
1821
+ values = optimizer_type.split(".")
1822
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
1823
+ optimizer_type = values[-1]
1824
+
1825
+ optimizer_class = getattr(optimizer_module, optimizer_type)
1826
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1827
+
1828
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
1829
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
1830
+
1831
+ return optimizer_name, optimizer_args, optimizer
1832
+
1833
+
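A hedged sketch of how the new optimizer flags feed into get_optimizer above; the flag values are illustrative only, the helper name is hypothetical, and trainable_params is whatever parameter iterable or param-group list the training script builds:

import argparse

def build_optimizer_from_cli(trainable_params):
    parser = argparse.ArgumentParser()
    add_optimizer_arguments(parser)
    args = parser.parse_args([
        "--optimizer_type", "AdamW",
        "--learning_rate", "1e-4",
        "--optimizer_args", "weight_decay=0.01", "betas=0.9,0.999",
    ])
    # get_optimizer splits each optimizer_args entry on '=' and ',' into floats/bools/tuples
    name, args_str, optimizer = get_optimizer(args, trainable_params)
    print(f"optimizer: {name} ({args_str})")
    return optimizer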
1834
+ # Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
1835
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
1836
+ # Which is a newer release of diffusers than currently packaged with sd-scripts
1837
+ # This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and integrated into sd-scripts
1838
+
1839
+
1840
+ def get_scheduler_fix(
1841
+ name: Union[str, SchedulerType],
1842
+ optimizer: Optimizer,
1843
+ num_warmup_steps: Optional[int] = None,
1844
+ num_training_steps: Optional[int] = None,
1845
+ num_cycles: int = 1,
1846
+ power: float = 1.0,
1847
+ ):
1848
+ """
1849
+ Unified API to get any scheduler from its name.
1850
+ Args:
1851
+ name (`str` or `SchedulerType`):
1852
+ The name of the scheduler to use.
1853
+ optimizer (`torch.optim.Optimizer`):
1854
+ The optimizer that will be used during training.
1855
+ num_warmup_steps (`int`, *optional*):
1856
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
1857
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1858
+ num_training_steps (`int`, *optional*):
1859
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
1860
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1861
+ num_cycles (`int`, *optional*):
1862
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
1863
+ power (`float`, *optional*, defaults to 1.0):
1864
+ Power factor. See `POLYNOMIAL` scheduler
1865
+ last_epoch (`int`, *optional*, defaults to -1):
1866
+ The index of the last epoch when resuming training.
1867
+ """
1868
+ if name.startswith("adafactor"):
1869
+ assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
1870
+ initial_lr = float(name.split(':')[1])
1871
+ # print("adafactor scheduler init lr", initial_lr)
1872
+ return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
1873
+
1874
+ name = SchedulerType(name)
1875
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
1876
+ if name == SchedulerType.CONSTANT:
1877
+ return schedule_func(optimizer)
1878
+
1879
+ # All other schedulers require `num_warmup_steps`
1880
+ if num_warmup_steps is None:
1881
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
1882
+
1883
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
1884
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
1885
+
1886
+ # All other schedulers require `num_training_steps`
1887
+ if num_training_steps is None:
1888
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1889
+
1890
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
1891
+ return schedule_func(
1892
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
1893
+ )
1894
+
1895
+ if name == SchedulerType.POLYNOMIAL:
1896
+ return schedule_func(
1897
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
1898
+ )
1899
+
1900
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
1901
+
1902
+
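As a usage sketch of the adafactor branch above (assumes get_scheduler_fix is importable from library.train_util; the dummy parameter and the 1e-4 initial lr are illustrative, not part of this commit):

import torch
import transformers
from library.train_util import get_scheduler_fix

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = transformers.optimization.Adafactor(params, lr=None, relative_step=True, warmup_init=True)
# "adafactor:1e-4" is the form get_optimizer writes into args.lr_scheduler when relative_step is True
lr_scheduler = get_scheduler_fix("adafactor:1e-4", optimizer)  # returns AdafactorSchedule(optimizer, 1e-4)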
1903
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1904
  # backward compatibility
1905
  if args.caption_extention is not None:
1906
  args.caption_extension = args.caption_extention
1907
  args.caption_extention = None
1908
 
 
 
 
 
1909
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1910
  if args.resolution is not None:
1911
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
 
1928
 
1929
  def load_tokenizer(args: argparse.Namespace):
1930
  print("prepare tokenizer")
1931
+ original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
1932
+
1933
+ tokenizer: CLIPTokenizer = None
1934
+ if args.tokenizer_cache_dir:
1935
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
1936
+ if os.path.exists(local_tokenizer_path):
1937
+ print(f"load tokenizer from cache: {local_tokenizer_path}")
1938
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
1939
+
1940
+ if tokenizer is None:
1941
+ if args.v2:
1942
+ tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
1943
+ else:
1944
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
1945
+
1946
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
1947
  print(f"update token length: {args.max_token_length}")
1948
+
1949
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
1950
+ print(f"save Tokenizer to cache: {local_tokenizer_path}")
1951
+ tokenizer.save_pretrained(local_tokenizer_path)
1952
+
1953
  return tokenizer
1954
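A small sketch of the cache-path convention used above (the repo id and directory are illustrative, not taken from this commit):

import os

original_path = "openai/clip-vit-large-patch14"  # hypothetical Hugging Face repo id
tokenizer_cache_dir = "tokenizer_cache"          # stands in for args.tokenizer_cache_dir
local_tokenizer_path = os.path.join(tokenizer_cache_dir, original_path.replace("/", "_"))
# -> "tokenizer_cache/openai_clip-vit-large-patch14"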
 
1955
 
 
2000
 
2001
 
2002
  def load_target_model(args: argparse.Namespace, weight_dtype):
2003
+ name_or_path = args.pretrained_model_name_or_path
2004
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
2005
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
2006
  if load_stable_diffusion_format:
2007
  print("load StableDiffusion checkpoint")
2008
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
2009
  else:
2010
  print("load Diffusers pretrained models")
2011
+ try:
2012
+ pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
2013
+ except EnvironmentError as ex:
2014
+ print(
2015
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
2016
  text_encoder = pipe.text_encoder
2017
  vae = pipe.vae
2018
  unet = pipe.unet
 
2181
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
2182
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
2183
 
2184
+
2185
+ # scheduler:
2186
+ SCHEDULER_LINEAR_START = 0.00085
2187
+ SCHEDULER_LINEAR_END = 0.0120
2188
+ SCHEDULER_TIMESTEPS = 1000
2189
+ SCHEDLER_SCHEDULE = 'scaled_linear'
2190
+
2191
+
2192
+ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
2193
+ """
2194
+ 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
2195
+ clip skipは対応した
2196
+ """
2197
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
2198
+ return
2199
+ if args.sample_every_n_epochs is not None:
2200
+ # sample_every_n_steps は無視する
2201
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
2202
+ return
2203
+ else:
2204
+ if steps % args.sample_every_n_steps != 0:
2205
+ return
2206
+
2207
+ print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
2208
+ if not os.path.isfile(args.sample_prompts):
2209
+ print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
2210
+ return
2211
+
2212
+ # ここでCUDAのキャッシュクリアとかしたほうがいいのか……
2213
+
2214
+ org_vae_device = vae.device # CPUにいるはず
2215
+ vae.to(device)
2216
+
2217
+ # clip skip 対応のための wrapper を作る
2218
+ if args.clip_skip is None:
2219
+ text_encoder_or_wrapper = text_encoder
2220
+ else:
2221
+ class Wrapper():
2222
+ def __init__(self, tenc) -> None:
2223
+ self.tenc = tenc
2224
+ self.config = {}
2225
+ super().__init__()
2226
+
2227
+ def __call__(self, input_ids, attention_mask):
2228
+ enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
2229
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
2230
+ encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
2231
+ pooled_output = enc_out['pooler_output']
2232
+ return encoder_hidden_states, pooled_output # only the 1st output is used
2233
+
2234
+ text_encoder_or_wrapper = Wrapper(text_encoder)
2235
+
2236
+ # read prompts
2237
+ with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
2238
+ prompts = f.readlines()
2239
+
2240
+ # schedulerを用意する
2241
+ sched_init_args = {}
2242
+ if args.sample_sampler == "ddim":
2243
+ scheduler_cls = DDIMScheduler
2244
+ elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
2245
+ scheduler_cls = DDPMScheduler
2246
+ elif args.sample_sampler == "pndm":
2247
+ scheduler_cls = PNDMScheduler
2248
+ elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
2249
+ scheduler_cls = LMSDiscreteScheduler
2250
+ elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
2251
+ scheduler_cls = EulerDiscreteScheduler
2252
+ elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
2253
+ scheduler_cls = EulerAncestralDiscreteScheduler
2254
+ elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
2255
+ scheduler_cls = DPMSolverMultistepScheduler
2256
+ sched_init_args['algorithm_type'] = args.sample_sampler
2257
+ elif args.sample_sampler == "dpmsingle":
2258
+ scheduler_cls = DPMSolverSinglestepScheduler
2259
+ elif args.sample_sampler == "heun":
2260
+ scheduler_cls = HeunDiscreteScheduler
2261
+ elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
2262
+ scheduler_cls = KDPM2DiscreteScheduler
2263
+ elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
2264
+ scheduler_cls = KDPM2AncestralDiscreteScheduler
2265
+ else:
2266
+ scheduler_cls = DDIMScheduler
2267
+
2268
+ if args.v_parameterization:
2269
+ sched_init_args['prediction_type'] = 'v_prediction'
2270
+
2271
+ scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
2272
+ beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
2273
+ beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
2274
+
2275
+ # clip_sample=Trueにする
2276
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
2277
+ # print("set clip_sample to True")
2278
+ scheduler.config.clip_sample = True
2279
+
2280
+ pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
2281
+ scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
2282
+ pipeline.to(device)
2283
+
2284
+ save_dir = args.output_dir + "/sample"
2285
+ os.makedirs(save_dir, exist_ok=True)
2286
+
2287
+ rng_state = torch.get_rng_state()
2288
+ cuda_rng_state = torch.cuda.get_rng_state()
2289
+
2290
+ with torch.no_grad():
2291
+ with accelerator.autocast():
2292
+ for i, prompt in enumerate(prompts):
2293
+ prompt = prompt.strip()
2294
+ if len(prompt) == 0 or prompt[0] == '#':
2295
+ continue
2296
+
2297
+ # subset of gen_img_diffusers
2298
+ prompt_args = prompt.split(' --')
2299
+ prompt = prompt_args[0]
2300
+ negative_prompt = None
2301
+ sample_steps = 30
2302
+ width = height = 512
2303
+ scale = 7.5
2304
+ seed = None
2305
+ for parg in prompt_args:
2306
+ try:
2307
+ m = re.match(r'w (\d+)', parg, re.IGNORECASE)
2308
+ if m:
2309
+ width = int(m.group(1))
2310
+ continue
2311
+
2312
+ m = re.match(r'h (\d+)', parg, re.IGNORECASE)
2313
+ if m:
2314
+ height = int(m.group(1))
2315
+ continue
2316
+
2317
+ m = re.match(r'd (\d+)', parg, re.IGNORECASE)
2318
+ if m:
2319
+ seed = int(m.group(1))
2320
+ continue
2321
+
2322
+ m = re.match(r's (\d+)', parg, re.IGNORECASE)
2323
+ if m: # steps
2324
+ sample_steps = max(1, min(1000, int(m.group(1))))
2325
+ continue
2326
+
2327
+ m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
2328
+ if m: # scale
2329
+ scale = float(m.group(1))
2330
+ continue
2331
+
2332
+ m = re.match(r'n (.+)', parg, re.IGNORECASE)
2333
+ if m: # negative prompt
2334
+ negative_prompt = m.group(1)
2335
+ continue
2336
+
2337
+ except ValueError as ex:
2338
+ print(f"Exception in parsing / 解析エラー: {parg}")
2339
+ print(ex)
2340
+
2341
+ if seed is not None:
2342
+ torch.manual_seed(seed)
2343
+ torch.cuda.manual_seed(seed)
2344
+
2345
+ if prompt_replacement is not None:
2346
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
2347
+ if negative_prompt is not None:
2348
+ negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
2349
+
2350
+ image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
2351
+
2352
+ ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
2353
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
2354
+ seed_suffix = "" if seed is None else f"_{seed}"
2355
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
2356
+
2357
+ image.save(os.path.join(save_dir, img_filename))
2358
+
2359
+ torch.set_rng_state(rng_state)
2360
+ torch.cuda.set_rng_state(cuda_rng_state)
2361
+ vae.to(org_vae_device)
2362
+
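The parser above reads args.sample_prompts line by line, skips lines starting with '#', and recognizes optional ' --' switches: w (width), h (height), d (seed), s (steps, clamped to 1..1000), l (CFG scale) and n (negative prompt). A hypothetical prompt file could therefore look like this (contents are illustrative only):

# lines beginning with '#' are ignored
a painting of a corgi wearing a crown --w 512 --h 768 --s 28 --l 7.5 --d 42 --n lowres, blurry
a watercolor landscape at sunset --s 30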
2363
  # endregion
2364
 
2365
  # region 前処理用
networks/lora.py CHANGED
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
 
 
 
 
 
129
  def load_weights(self, file):
130
  if os.path.splitext(file)[1] == '.safetensors':
131
  from safetensors.torch import load_file, safe_open
 
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
129
+ def set_multiplier(self, multiplier):
130
+ self.multiplier = multiplier
131
+ for lora in self.text_encoder_loras + self.unet_loras:
132
+ lora.multiplier = self.multiplier
133
+
134
  def load_weights(self, file):
135
  if os.path.splitext(file)[1] == '.safetensors':
136
  from safetensors.torch import load_file, safe_open
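The new set_multiplier method writes a single scale into every LoRA module of the network. A minimal usage sketch (assumes a LoRANetwork instance named network whose modules are already applied to the model; not part of this commit):

network.set_multiplier(0.8)  # scale both text-encoder and U-Net LoRA contributions to 0.8
# ... run generation ...
network.set_multiplier(0.0)  # effectively disables the LoRA without removing it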
requirements.txt CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
 
 
15
  # for BLIP captioning
16
  requests==2.28.2
17
  timm==0.6.12
@@ -21,5 +23,4 @@ fairscale==0.4.13
21
  tensorflow==2.10.1
22
  huggingface-hub==0.12.0
23
  # for kohya_ss library
24
- #locon.locon_kohya
25
  .
 
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
15
+ toml==0.10.2
16
+ voluptuous==0.13.1
17
  # for BLIP captioning
18
  requests==2.28.2
19
  timm==0.6.12
 
23
  tensorflow==2.10.1
24
  huggingface-hub==0.12.0
25
  # for kohya_ss library
 
26
  .
train_db.py CHANGED
@@ -15,7 +15,11 @@ import diffusers
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
- from library.train_util import DreamBoothDataset
 
 
 
 
19
 
20
 
21
  def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):
33
 
34
  tokenizer = train_util.load_tokenizer(args)
35
 
36
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
37
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
38
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
39
- args.bucket_reso_steps, args.bucket_no_upscale,
40
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
41
-
42
- if args.no_token_padding:
43
- train_dataset.disable_token_padding()
 
 
 
 
 
44
 
45
- # 学習データのdropout率を設定する
46
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
47
 
48
- train_dataset.make_buckets()
 
49
 
50
  if args.debug_dataset:
51
- train_util.debug_dataset(train_dataset)
52
  return
53
 
 
 
 
54
  # acceleratorを準備する
55
  print("prepare accelerator")
56
 
@@ -91,7 +104,7 @@ def train(args):
91
  vae.requires_grad_(False)
92
  vae.eval()
93
  with torch.no_grad():
94
- train_dataset.cache_latents(vae)
95
  vae.to("cpu")
96
  if torch.cuda.is_available():
97
  torch.cuda.empty_cache()
@@ -115,38 +128,18 @@ def train(args):
115
 
116
  # 学習に必要なクラスを準備する
117
  print("prepare optimizer, data loader etc.")
118
-
119
- # 8-bit Adamを使う
120
- if args.use_8bit_adam:
121
- try:
122
- import bitsandbytes as bnb
123
- except ImportError:
124
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
125
- print("use 8-bit Adam optimizer")
126
- optimizer_class = bnb.optim.AdamW8bit
127
- elif args.use_lion_optimizer:
128
- try:
129
- import lion_pytorch
130
- except ImportError:
131
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
132
- print("use Lion optimizer")
133
- optimizer_class = lion_pytorch.Lion
134
- else:
135
- optimizer_class = torch.optim.AdamW
136
-
137
  if train_text_encoder:
138
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
139
  else:
140
  trainable_params = unet.parameters()
141
 
142
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
143
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
144
 
145
  # dataloaderを準備する
146
  # DataLoaderのプロセス数:0はメインプロセスになる
147
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
148
  train_dataloader = torch.utils.data.DataLoader(
149
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
150
 
151
  # 学習ステップ数を計算する
152
  if args.max_train_epochs is not None:
@@ -156,9 +149,10 @@ def train(args):
156
  if args.stop_text_encoder_training is None:
157
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
158
 
159
- # lr schedulerを用意する
160
- lr_scheduler = diffusers.optimization.get_scheduler(
161
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
162
 
163
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
164
  if args.full_fp16:
@@ -195,8 +189,8 @@ def train(args):
195
  # 学習する
196
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
197
  print("running training / 学習開始")
198
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
199
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
200
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
201
  print(f" num epochs / epoch数: {num_train_epochs}")
202
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -217,7 +211,7 @@ def train(args):
217
  loss_total = 0.0
218
  for epoch in range(num_train_epochs):
219
  print(f"epoch {epoch+1}/{num_train_epochs}")
220
- train_dataset.set_current_epoch(epoch + 1)
221
 
222
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
223
  unet.train()
@@ -281,12 +275,12 @@ def train(args):
281
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
282
 
283
  accelerator.backward(loss)
284
- if accelerator.sync_gradients:
285
  if train_text_encoder:
286
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
287
  else:
288
  params_to_clip = unet.parameters()
289
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
@@ -297,9 +291,13 @@ def train(args):
297
  progress_bar.update(1)
298
  global_step += 1
299
 
 
 
300
  current_loss = loss.detach().item()
301
  if args.logging_dir is not None:
302
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
303
  accelerator.log(logs, step=global_step)
304
 
305
  if epoch == 0:
@@ -326,6 +324,8 @@ def train(args):
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
 
 
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
@@ -352,6 +352,8 @@ if __name__ == '__main__':
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
 
 
355
 
356
  parser.add_argument("--no_token_padding", action="store_true",
357
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
 
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
+ import library.config_util as config_util
19
+ from library.config_util import (
20
+ ConfigSanitizer,
21
+ BlueprintGenerator,
22
+ )
23
 
24
 
25
  def collate_fn(examples):
 
37
 
38
  tokenizer = train_util.load_tokenizer(args)
39
 
40
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
41
+ if args.dataset_config is not None:
42
+ print(f"Load dataset config from {args.dataset_config}")
43
+ user_config = config_util.load_user_config(args.dataset_config)
44
+ ignored = ["train_data_dir", "reg_data_dir"]
45
+ if any(getattr(args, attr) is not None for attr in ignored):
46
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
47
+ else:
48
+ user_config = {
49
+ "datasets": [{
50
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
51
+ }]
52
+ }
53
 
54
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
+ if args.no_token_padding:
58
+ train_dataset_group.disable_token_padding()
59
 
60
  if args.debug_dataset:
61
+ train_util.debug_dataset(train_dataset_group)
62
  return
63
 
64
+ if cache_latents:
65
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
+
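When args.dataset_config is set, config_util.load_user_config presumably reads a TOML file (toml is added to requirements.txt in this commit) with the same nested datasets/subsets shape as the fallback user_config dict above. A hypothetical minimal config; the field names below are assumptions that mirror the subset attributes referenced elsewhere in this commit (image_dir, num_repeats, is_reg), and the authoritative schema lives in library/config_util.py:

[[datasets]]

  [[datasets.subsets]]
  image_dir = "train/10_sks_dog"
  num_repeats = 10        # assumed key, mirrors subset.num_repeats

  [[datasets.subsets]]
  image_dir = "reg/1_dog"
  is_reg = true           # assumed key, mirrors subset.is_reg
  num_repeats = 1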
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
 
 
104
  vae.requires_grad_(False)
105
  vae.eval()
106
  with torch.no_grad():
107
+ train_dataset_group.cache_latents(vae)
108
  vae.to("cpu")
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
 
128
 
129
  # 学習に必要なクラスを準備する
130
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if train_text_encoder:
132
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
133
  else:
134
  trainable_params = unet.parameters()
135
 
136
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
137
 
138
  # dataloaderを準備する
139
  # DataLoaderのプロセス数:0はメインプロセスになる
140
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
141
  train_dataloader = torch.utils.data.DataLoader(
142
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
143
 
144
  # 学習ステップ数を計算する
145
  if args.max_train_epochs is not None:
 
149
  if args.stop_text_encoder_training is None:
150
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
151
 
152
+ # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
153
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
154
+ num_training_steps=args.max_train_steps,
155
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
156
 
157
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
158
  if args.full_fp16:
 
189
  # 学習する
190
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
191
  print("running training / 学習開始")
192
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
193
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
194
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
195
  print(f" num epochs / epoch数: {num_train_epochs}")
196
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
211
  loss_total = 0.0
212
  for epoch in range(num_train_epochs):
213
  print(f"epoch {epoch+1}/{num_train_epochs}")
214
+ train_dataset_group.set_current_epoch(epoch + 1)
215
 
216
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
217
  unet.train()
 
275
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
276
 
277
  accelerator.backward(loss)
278
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
279
  if train_text_encoder:
280
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
281
  else:
282
  params_to_clip = unet.parameters()
283
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
284
 
285
  optimizer.step()
286
  lr_scheduler.step()
 
291
  progress_bar.update(1)
292
  global_step += 1
293
 
294
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
295
+
296
  current_loss = loss.detach().item()
297
  if args.logging_dir is not None:
298
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
299
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
300
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
301
  accelerator.log(logs, step=global_step)
302
 
303
  if epoch == 0:
 
324
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
325
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
326
 
327
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
328
+
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
 
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
355
+ train_util.add_optimizer_arguments(parser)
356
+ config_util.add_config_arguments(parser)
357
 
358
  parser.add_argument("--no_token_padding", action="store_true",
359
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_network.py CHANGED
@@ -1,8 +1,4 @@
1
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
- from torch.optim import Optimizer
3
- from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
- from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
@@ -15,92 +11,39 @@ import json
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
- import diffusers
19
  from diffusers import DDPMScheduler
20
 
21
  import library.train_util as train_util
22
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
23
 
24
 
25
  def collate_fn(examples):
26
  return examples[0]
27
 
28
 
 
29
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
30
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
31
 
32
  if args.network_train_unet_only:
33
- logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
34
  elif args.network_train_text_encoder_only:
35
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
36
  else:
37
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
38
- logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
39
-
40
- return logs
41
 
 
 
42
 
43
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
44
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
45
- # Which is a newer release of diffusers than currently packaged with sd-scripts
46
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
47
-
48
-
49
- def get_scheduler_fix(
50
- name: Union[str, SchedulerType],
51
- optimizer: Optimizer,
52
- num_warmup_steps: Optional[int] = None,
53
- num_training_steps: Optional[int] = None,
54
- num_cycles: int = 1,
55
- power: float = 1.0,
56
- ):
57
- """
58
- Unified API to get any scheduler from its name.
59
- Args:
60
- name (`str` or `SchedulerType`):
61
- The name of the scheduler to use.
62
- optimizer (`torch.optim.Optimizer`):
63
- The optimizer that will be used during training.
64
- num_warmup_steps (`int`, *optional*):
65
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
66
- optional), the function will raise an error if it's unset and the scheduler type requires it.
67
- num_training_steps (`int``, *optional*):
68
- The number of training steps to do. This is not required by all schedulers (hence the argument being
69
- optional), the function will raise an error if it's unset and the scheduler type requires it.
70
- num_cycles (`int`, *optional*):
71
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
72
- power (`float`, *optional*, defaults to 1.0):
73
- Power factor. See `POLYNOMIAL` scheduler
74
- last_epoch (`int`, *optional*, defaults to -1):
75
- The index of the last epoch when resuming training.
76
- """
77
- name = SchedulerType(name)
78
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
79
- if name == SchedulerType.CONSTANT:
80
- return schedule_func(optimizer)
81
-
82
- # All other schedulers require `num_warmup_steps`
83
- if num_warmup_steps is None:
84
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
85
-
86
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
87
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
88
-
89
- # All other schedulers require `num_training_steps`
90
- if num_training_steps is None:
91
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
92
-
93
- if name == SchedulerType.COSINE_WITH_RESTARTS:
94
- return schedule_func(
95
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
96
- )
97
-
98
- if name == SchedulerType.POLYNOMIAL:
99
- return schedule_func(
100
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
101
- )
102
-
103
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
104
 
105
 
106
  def train(args):
@@ -111,6 +54,7 @@ def train(args):
111
 
112
  cache_latents = args.cache_latents
113
  use_dreambooth_method = args.in_json is None
 
114
 
115
  if args.seed is not None:
116
  set_seed(args.seed)
@@ -118,35 +62,47 @@ def train(args):
118
  tokenizer = train_util.load_tokenizer(args)
119
 
120
  # データセットを準備する
121
- if use_dreambooth_method:
122
- print("Use DreamBooth method.")
123
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
124
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
125
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
126
- args.bucket_reso_steps, args.bucket_no_upscale,
127
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
128
- args.random_crop, args.debug_dataset)
129
  else:
130
- print("Train with captions.")
131
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
132
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
133
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
134
- args.bucket_reso_steps, args.bucket_no_upscale,
135
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
136
- args.dataset_repeats, args.debug_dataset)
137
-
138
- # 学習データのdropout率を設定する
139
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
140
-
141
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
142
 
143
  if args.debug_dataset:
144
- train_util.debug_dataset(train_dataset)
145
  return
146
- if len(train_dataset) == 0:
147
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
148
  return
149
 
 
 
 
 
150
  # acceleratorを準備する
151
  print("prepare accelerator")
152
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -161,7 +117,7 @@ def train(args):
161
  if args.lowram:
162
  text_encoder.to("cuda")
163
  unet.to("cuda")
164
-
165
  # モデルに xformers とか memory efficient attention を組み込む
166
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
167
 
@@ -171,7 +127,7 @@ def train(args):
171
  vae.requires_grad_(False)
172
  vae.eval()
173
  with torch.no_grad():
174
- train_dataset.cache_latents(vae)
175
  vae.to("cpu")
176
  if torch.cuda.is_available():
177
  torch.cuda.empty_cache()
@@ -208,36 +164,14 @@ def train(args):
208
  # 学習に必要なクラスを準備する
209
  print("prepare optimizer, data loader etc.")
210
 
211
- # 8-bit Adamを使う
212
- if args.use_8bit_adam:
213
- try:
214
- import bitsandbytes as bnb
215
- except ImportError:
216
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
217
- print("use 8-bit Adam optimizer")
218
- optimizer_class = bnb.optim.AdamW8bit
219
- elif args.use_lion_optimizer:
220
- try:
221
- import lion_pytorch
222
- except ImportError:
223
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされて��ないようです")
224
- print("use Lion optimizer")
225
- optimizer_class = lion_pytorch.Lion
226
- else:
227
- optimizer_class = torch.optim.AdamW
228
-
229
- optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
230
-
231
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
232
-
233
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
234
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235
 
236
  # dataloaderを準備する
237
  # DataLoaderのプロセス数:0はメインプロセスになる
238
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
239
  train_dataloader = torch.utils.data.DataLoader(
240
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
241
 
242
  # 学習ステップ数を計算する
243
  if args.max_train_epochs is not None:
@@ -245,11 +179,9 @@ def train(args):
245
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
246
 
247
  # lr schedulerを用意する
248
- # lr_scheduler = diffusers.optimization.get_scheduler(
249
- lr_scheduler = get_scheduler_fix(
250
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
251
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
252
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
253
 
254
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
255
  if args.full_fp16:
@@ -317,17 +249,19 @@ def train(args):
317
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
318
 
319
  # 学習する
 
320
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
321
  print("running training / 学習開始")
322
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
323
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
324
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
325
  print(f" num epochs / epoch数: {num_train_epochs}")
326
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
327
- print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
328
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
329
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
330
 
 
331
  metadata = {
332
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
333
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -335,12 +269,10 @@ def train(args):
335
  "ss_learning_rate": args.learning_rate,
336
  "ss_text_encoder_lr": args.text_encoder_lr,
337
  "ss_unet_lr": args.unet_lr,
338
- "ss_num_train_images": train_dataset.num_train_images, # includes repeating
339
- "ss_num_reg_images": train_dataset.num_reg_images,
340
  "ss_num_batches_per_epoch": len(train_dataloader),
341
  "ss_num_epochs": num_train_epochs,
342
- "ss_batch_size_per_device": args.train_batch_size,
343
- "ss_total_batch_size": total_batch_size,
344
  "ss_gradient_checkpointing": args.gradient_checkpointing,
345
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
346
  "ss_max_train_steps": args.max_train_steps,
@@ -352,29 +284,149 @@ def train(args):
352
  "ss_mixed_precision": args.mixed_precision,
353
  "ss_full_fp16": bool(args.full_fp16),
354
  "ss_v2": bool(args.v2),
355
- "ss_resolution": args.resolution,
356
  "ss_clip_skip": args.clip_skip,
357
  "ss_max_token_length": args.max_token_length,
358
- "ss_color_aug": bool(args.color_aug),
359
- "ss_flip_aug": bool(args.flip_aug),
360
- "ss_random_crop": bool(args.random_crop),
361
- "ss_shuffle_caption": bool(args.shuffle_caption),
362
  "ss_cache_latents": bool(args.cache_latents),
363
- "ss_enable_bucket": bool(train_dataset.enable_bucket),
364
- "ss_min_bucket_reso": train_dataset.min_bucket_reso,
365
- "ss_max_bucket_reso": train_dataset.max_bucket_reso,
366
  "ss_seed": args.seed,
367
- "ss_keep_tokens": args.keep_tokens,
368
  "ss_noise_offset": args.noise_offset,
369
- "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
370
- "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
371
- "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
372
- "ss_bucket_info": json.dumps(train_dataset.bucket_info),
373
  "ss_training_comment": args.training_comment, # will not be updated after training
374
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
375
- "ss_optimizer": optimizer_name
 
 
 
 
 
 
376
  }
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  # uncomment if another network is added
379
  # for key, value in net_kwargs.items():
380
  # metadata["ss_arg_" + key] = value
@@ -410,7 +462,7 @@ def train(args):
410
  loss_total = 0.0
411
  for epoch in range(num_train_epochs):
412
  print(f"epoch {epoch+1}/{num_train_epochs}")
413
- train_dataset.set_current_epoch(epoch + 1)
414
 
415
  metadata["ss_epoch"] = str(epoch+1)
416
 
@@ -447,7 +499,7 @@ def train(args):
447
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
448
 
449
  # Predict the noise residual
450
- with autocast():
451
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
452
 
453
  if args.v_parameterization:
@@ -465,9 +517,9 @@ def train(args):
465
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
466
 
467
  accelerator.backward(loss)
468
- if accelerator.sync_gradients:
469
  params_to_clip = network.get_trainable_params()
470
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
471
 
472
  optimizer.step()
473
  lr_scheduler.step()
@@ -478,6 +530,8 @@ def train(args):
478
  progress_bar.update(1)
479
  global_step += 1
480
 
 
 
481
  current_loss = loss.detach().item()
482
  if epoch == 0:
483
  loss_list.append(current_loss)
@@ -508,6 +562,7 @@ def train(args):
508
  def save_func():
509
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
510
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
511
  print(f"saving checkpoint: {ckpt_file}")
512
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
513
 
@@ -522,9 +577,12 @@ def train(args):
522
  if saving and args.save_state:
523
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
524
 
 
 
525
  # end of epoch
526
 
527
  metadata["ss_epoch"] = str(num_train_epochs)
 
528
 
529
  is_main_process = accelerator.is_main_process
530
  if is_main_process:
@@ -555,6 +613,8 @@ if __name__ == '__main__':
555
  train_util.add_sd_models_arguments(parser)
556
  train_util.add_dataset_arguments(parser, True, True, True)
557
  train_util.add_training_arguments(parser, True)
 
 
558
 
559
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
560
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -562,10 +622,6 @@ if __name__ == '__main__':
562
 
563
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
564
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
565
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
566
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
567
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
568
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
569
 
570
  parser.add_argument("--network_weights", type=str, default=None,
571
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
 
 
 
1
  from torch.nn.parallel import DistributedDataParallel as DDP
 
2
  import importlib
3
  import argparse
4
  import gc
 
11
  from tqdm import tqdm
12
  import torch
13
  from accelerate.utils import set_seed
 
14
  from diffusers import DDPMScheduler
15
 
16
  import library.train_util as train_util
17
+ from library.train_util import (
18
+ DreamBoothDataset,
19
+ )
20
+ import library.config_util as config_util
21
+ from library.config_util import (
22
+ ConfigSanitizer,
23
+ BlueprintGenerator,
24
+ )
25
 
26
 
27
  def collate_fn(examples):
28
  return examples[0]
29
 
30
 
31
+ # TODO 他のスクリプトと共通化する
32
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
33
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
34
 
35
  if args.network_train_unet_only:
36
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
37
  elif args.network_train_text_encoder_only:
38
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
39
  else:
40
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
41
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
 
 
42
 
43
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
44
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
45
 
46
+ return logs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  def train(args):
 
54
 
55
  cache_latents = args.cache_latents
56
  use_dreambooth_method = args.in_json is None
57
+ use_user_config = args.dataset_config is not None
58
 
59
  if args.seed is not None:
60
  set_seed(args.seed)
 
62
  tokenizer = train_util.load_tokenizer(args)
63
 
64
  # データセットを準備する
65
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
66
+ if use_user_config:
67
+ print(f"Load dataset config from {args.dataset_config}")
68
+ user_config = config_util.load_user_config(args.dataset_config)
69
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
70
+ if any(getattr(args, attr) is not None for attr in ignored):
71
+ print(
72
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
73
  else:
74
+ if use_dreambooth_method:
75
+ print("Use DreamBooth method.")
76
+ user_config = {
77
+ "datasets": [{
78
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
79
+ }]
80
+ }
81
+ else:
82
+ print("Train with captions.")
83
+ user_config = {
84
+ "datasets": [{
85
+ "subsets": [{
86
+ "image_dir": args.train_data_dir,
87
+ "metadata_file": args.in_json,
88
+ }]
89
+ }]
90
+ }
91
+
92
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
93
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
94
 
95
  if args.debug_dataset:
96
+ train_util.debug_dataset(train_dataset_group)
97
  return
98
+ if len(train_dataset_group) == 0:
99
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
100
  return
101
 
102
+ if cache_latents:
103
+ assert train_dataset_group.is_latent_cacheable(
104
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
105
+
106
  # acceleratorを準備する
107
  print("prepare accelerator")
108
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
117
  if args.lowram:
118
  text_encoder.to("cuda")
119
  unet.to("cuda")
120
+
121
  # モデルに xformers とか memory efficient attention を組み込む
122
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
123
 
 
127
  vae.requires_grad_(False)
128
  vae.eval()
129
  with torch.no_grad():
130
+ train_dataset_group.cache_latents(vae)
131
  vae.to("cpu")
132
  if torch.cuda.is_available():
133
  torch.cuda.empty_cache()
 
164
  # 学習に必要なクラスを準備する
165
  print("prepare optimizer, data loader etc.")
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
168
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
169
 
170
  # dataloaderを準備する
171
  # DataLoaderのプロセス数:0はメインプロセスになる
172
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
173
  train_dataloader = torch.utils.data.DataLoader(
174
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
175
 
176
  # 学習ステップ数を計算する
177
  if args.max_train_epochs is not None:
 
179
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
180
 
181
  # lr schedulerを用意する
182
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
183
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
184
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
 
185
 
186
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
187
  if args.full_fp16:
 
249
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
250
 
251
  # 学習する
252
+ # TODO: find a way to handle total batch size when there are multiple datasets
253
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
254
  print("running training / 学習開始")
255
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
256
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
257
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
258
  print(f" num epochs / epoch数: {num_train_epochs}")
259
+ print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
260
+ # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
261
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
262
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
263
 
264
+ # TODO refactor metadata creation and move to util
265
  metadata = {
266
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
267
  "ss_training_started_at": training_started_at, # unix timestamp
 
269
  "ss_learning_rate": args.learning_rate,
270
  "ss_text_encoder_lr": args.text_encoder_lr,
271
  "ss_unet_lr": args.unet_lr,
272
+ "ss_num_train_images": train_dataset_group.num_train_images,
273
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
274
  "ss_num_batches_per_epoch": len(train_dataloader),
275
  "ss_num_epochs": num_train_epochs,
 
 
276
  "ss_gradient_checkpointing": args.gradient_checkpointing,
277
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
278
  "ss_max_train_steps": args.max_train_steps,
 
284
  "ss_mixed_precision": args.mixed_precision,
285
  "ss_full_fp16": bool(args.full_fp16),
286
  "ss_v2": bool(args.v2),
 
287
  "ss_clip_skip": args.clip_skip,
288
  "ss_max_token_length": args.max_token_length,
 
 
 
 
289
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
290
  "ss_seed": args.seed,
291
+ "ss_lowram": args.lowram,
292
  "ss_noise_offset": args.noise_offset,
 
 
 
 
293
  "ss_training_comment": args.training_comment, # will not be updated after training
294
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
295
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
296
+ "ss_max_grad_norm": args.max_grad_norm,
297
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
298
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
299
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
300
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
301
+ "ss_prior_loss_weight": args.prior_loss_weight,
302
  }
303
 
304
+ if use_user_config:
305
+ # save metadata of multiple datasets
306
+ # NOTE: pack "ss_datasets" value as json one time
307
+ # or should also pack nested collections as json?
308
+ datasets_metadata = []
309
+ tag_frequency = {} # merge tag frequency for metadata editor
310
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
311
+
312
+ for dataset in train_dataset_group.datasets:
313
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
314
+ dataset_metadata = {
315
+ "is_dreambooth": is_dreambooth_dataset,
316
+ "batch_size_per_device": dataset.batch_size,
317
+ "num_train_images": dataset.num_train_images, # includes repeating
318
+ "num_reg_images": dataset.num_reg_images,
319
+ "resolution": (dataset.width, dataset.height),
320
+ "enable_bucket": bool(dataset.enable_bucket),
321
+ "min_bucket_reso": dataset.min_bucket_reso,
322
+ "max_bucket_reso": dataset.max_bucket_reso,
323
+ "tag_frequency": dataset.tag_frequency,
324
+ "bucket_info": dataset.bucket_info,
325
+ }
326
+
327
+ subsets_metadata = []
328
+ for subset in dataset.subsets:
329
+ subset_metadata = {
330
+ "img_count": subset.img_count,
331
+ "num_repeats": subset.num_repeats,
332
+ "color_aug": bool(subset.color_aug),
333
+ "flip_aug": bool(subset.flip_aug),
334
+ "random_crop": bool(subset.random_crop),
335
+ "shuffle_caption": bool(subset.shuffle_caption),
336
+ "keep_tokens": subset.keep_tokens,
337
+ }
338
+
339
+ image_dir_or_metadata_file = None
340
+ if subset.image_dir:
341
+ image_dir = os.path.basename(subset.image_dir)
342
+ subset_metadata["image_dir"] = image_dir
343
+ image_dir_or_metadata_file = image_dir
344
+
345
+ if is_dreambooth_dataset:
346
+ subset_metadata["class_tokens"] = subset.class_tokens
347
+ subset_metadata["is_reg"] = subset.is_reg
348
+ if subset.is_reg:
349
+ image_dir_or_metadata_file = None # not merging reg dataset
350
+ else:
351
+ metadata_file = os.path.basename(subset.metadata_file)
352
+ subset_metadata["metadata_file"] = metadata_file
353
+ image_dir_or_metadata_file = metadata_file # may overwrite
354
+
355
+ subsets_metadata.append(subset_metadata)
356
+
357
+ # merge dataset dir: not reg subset only
358
+ # TODO update additional-network extension to show detailed dataset config from metadata
359
+ if image_dir_or_metadata_file is not None:
360
+ # datasets may have a certain dir multiple times
361
+ v = image_dir_or_metadata_file
362
+ i = 2
363
+ while v in dataset_dirs_info:
364
+ v = image_dir_or_metadata_file + f" ({i})"
365
+ i += 1
366
+ image_dir_or_metadata_file = v
367
+
368
+ dataset_dirs_info[image_dir_or_metadata_file] = {
369
+ "n_repeats": subset.num_repeats,
370
+ "img_count": subset.img_count
371
+ }
372
+
373
+ dataset_metadata["subsets"] = subsets_metadata
374
+ datasets_metadata.append(dataset_metadata)
375
+
376
+ # merge tag frequency:
377
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
378
+ # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
379
+ # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
380
+ # なので、ここで複数datasetの回数を合算してもあまり意味はない
381
+ if ds_dir_name in tag_frequency:
382
+ continue
383
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
384
+
385
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
386
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
387
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
388
+ else:
389
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
390
+ assert len(
391
+ train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
392
+
393
+ dataset = train_dataset_group.datasets[0]
394
+
395
+ dataset_dirs_info = {}
396
+ reg_dataset_dirs_info = {}
397
+ if use_dreambooth_method:
398
+ for subset in dataset.subsets:
399
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
400
+ info[os.path.basename(subset.image_dir)] = {
401
+ "n_repeats": subset.num_repeats,
402
+ "img_count": subset.img_count
403
+ }
404
+ else:
405
+ for subset in dataset.subsets:
406
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
407
+ "n_repeats": subset.num_repeats,
408
+ "img_count": subset.img_count
409
+ }
410
+
411
+ metadata.update({
412
+ "ss_batch_size_per_device": args.train_batch_size,
413
+ "ss_total_batch_size": total_batch_size,
414
+ "ss_resolution": args.resolution,
415
+ "ss_color_aug": bool(args.color_aug),
416
+ "ss_flip_aug": bool(args.flip_aug),
417
+ "ss_random_crop": bool(args.random_crop),
418
+ "ss_shuffle_caption": bool(args.shuffle_caption),
419
+ "ss_enable_bucket": bool(dataset.enable_bucket),
420
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
421
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
422
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
423
+ "ss_keep_tokens": args.keep_tokens,
424
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
425
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
426
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
427
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
428
+ })
429
+
430
  # uncomment if another network is added
431
  # for key, value in net_kwargs.items():
432
  # metadata["ss_arg_" + key] = value
 
462
  loss_total = 0.0
463
  for epoch in range(num_train_epochs):
464
  print(f"epoch {epoch+1}/{num_train_epochs}")
465
+ train_dataset_group.set_current_epoch(epoch + 1)
466
 
467
  metadata["ss_epoch"] = str(epoch+1)
468
 
 
499
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
500
 
501
  # Predict the noise residual
502
+ with accelerator.autocast():
503
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
504
 
505
  if args.v_parameterization:
 
517
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
518
 
519
  accelerator.backward(loss)
520
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
521
  params_to_clip = network.get_trainable_params()
522
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
523
 
524
  optimizer.step()
525
  lr_scheduler.step()
 
530
  progress_bar.update(1)
531
  global_step += 1
532
 
533
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
534
+
535
  current_loss = loss.detach().item()
536
  if epoch == 0:
537
  loss_list.append(current_loss)
 
562
  def save_func():
563
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
564
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
565
+ metadata["ss_training_finished_at"] = str(time.time())
566
  print(f"saving checkpoint: {ckpt_file}")
567
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
568
 
 
577
  if saving and args.save_state:
578
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
579
 
580
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
581
+
582
  # end of epoch
583
 
584
  metadata["ss_epoch"] = str(num_train_epochs)
585
+ metadata["ss_training_finished_at"] = str(time.time())
586
 
587
  is_main_process = accelerator.is_main_process
588
  if is_main_process:
 
613
  train_util.add_sd_models_arguments(parser)
614
  train_util.add_dataset_arguments(parser, True, True, True)
615
  train_util.add_training_arguments(parser, True)
616
+ train_util.add_optimizer_arguments(parser)
617
+ config_util.add_config_arguments(parser)
618
 
619
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
620
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
622
 
623
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
624
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
625
 
626
  parser.add_argument("--network_weights", type=str, default=None,
627
  help="pretrained weights for network / 学習するネットワークの初期重み")
train_network_opt.py CHANGED
@@ -1,8 +1,5 @@
1
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
- from torch.optim import Optimizer
3
  from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
- from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
@@ -17,136 +14,47 @@ import torch
17
  from accelerate.utils import set_seed
18
  import diffusers
19
  from diffusers import DDPMScheduler
20
- print("**********************************")
21
- #先に
22
- #pip install torch_optimizer
23
- #が必要
24
- try:
25
- import torch_optimizer as optim
26
- except:
27
- print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
28
- try:
29
- import adastand
30
- except:
31
- print("※Adastandが使えません")
32
-
33
- from transformers.optimization import Adafactor, AdafactorSchedule
34
- print("**********************************")
35
  ##### バケット拡張のためのモジュール
36
  import append_module
37
  ######
38
  import library.train_util as train_util
39
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
 
 
 
40
 
41
 
42
  def collate_fn(examples):
43
  return examples[0]
44
 
45
 
46
- def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
 
47
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
48
-
49
- if args.network_train_unet_only:
50
- logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
51
- elif args.network_train_text_encoder_only:
52
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
 
 
 
53
  else:
54
  last_lrs = lr_scheduler.get_last_lr()
55
- if len(last_lrs) == 2:
56
- logs["lr/textencoder"] = float(last_lrs[0])
57
- logs["lr/unet"] = float(last_lrs[-1]) # may be same to textencoder
58
- else:
59
- if len(last_lrs) == 4:
60
- logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
61
- elif len(last_lrs) == 8:
62
- logs_names = ["textencoder", "unet_midblock"]
63
- for i in range(3):
64
- logs_names.append(f"unet_down_blocks_{i}")
65
- logs_names.append(f"unet_up_blocks_{i+1}")
66
- else:
67
- logs_names = []
68
- for i in range(12):
69
- logs_names.append(f"text_model_encoder_layers_{i}_")
70
- logs_names.append("unet_midblock")
71
- for i in range(3):
72
- logs_names.append(f"unet_down_blocks_{i}")
73
- logs_names.append(f"unet_up_blocks_{i+1}")
74
-
75
- for last_lr, logs_name in zip(last_lrs, logs_names):
76
- logs[f"lr/{logs_name}"] = float(last_lr)
77
 
78
  return logs
79
 
80
 
81
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
82
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
83
- # Which is a newer release of diffusers than currently packaged with sd-scripts
84
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
85
-
86
-
87
- def get_scheduler_fix(
88
- name: Union[str, SchedulerType],
89
- optimizer: Optimizer,
90
- num_warmup_steps: Optional[int] = None,
91
- num_training_steps: Optional[int] = None,
92
- num_cycles: float = 1.,
93
- power: float = 1.0,
94
- ):
95
- """
96
- Unified API to get any scheduler from its name.
97
- Args:
98
- name (`str` or `SchedulerType`):
99
- The name of the scheduler to use.
100
- optimizer (`torch.optim.Optimizer`):
101
- The optimizer that will be used during training.
102
- num_warmup_steps (`int`, *optional*):
103
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
104
- optional), the function will raise an error if it's unset and the scheduler type requires it.
105
- num_training_steps (`int``, *optional*):
106
- The number of training steps to do. This is not required by all schedulers (hence the argument being
107
- optional), the function will raise an error if it's unset and the scheduler type requires it.
108
- num_cycles (`int`, *optional*):
109
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
110
- power (`float`, *optional*, defaults to 1.0):
111
- Power factor. See `POLYNOMIAL` scheduler
112
- last_epoch (`int`, *optional*, defaults to -1):
113
- The index of the last epoch when resuming training.
114
- """
115
- name = SchedulerType(name)
116
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
117
- if name == SchedulerType.CONSTANT:
118
- return schedule_func(optimizer)
119
-
120
- # All other schedulers require `num_warmup_steps`
121
- if num_warmup_steps is None:
122
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
123
-
124
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
125
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
126
-
127
- # All other schedulers require `num_training_steps`
128
- if num_training_steps is None:
129
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
130
-
131
- if name == SchedulerType.COSINE:
132
- print(f"{name} num_cycles: {num_cycles}")
133
- return schedule_func(
134
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
135
- )
136
- if name == SchedulerType.COSINE_WITH_RESTARTS:
137
- print(f"{name} num_cycles: {int(num_cycles)}")
138
- return schedule_func(
139
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
140
- )
141
-
142
- if name == SchedulerType.POLYNOMIAL:
143
- return schedule_func(
144
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
145
- )
146
-
147
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
148
-
149
-
150
  def train(args):
151
  session_id = random.randint(0, 2**32)
152
  training_started_at = time.time()
@@ -155,6 +63,7 @@ def train(args):
155
 
156
  cache_latents = args.cache_latents
157
  use_dreambooth_method = args.in_json is None
 
158
 
159
  if args.seed is not None:
160
  set_seed(args.seed)
@@ -162,40 +71,56 @@ def train(args):
162
  tokenizer = train_util.load_tokenizer(args)
163
 
164
  # データセットを準備する
165
- if use_dreambooth_method:
166
- if args.min_resolution:
167
- args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
168
- if len(args.min_resolution) == 1:
169
- args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
170
-
171
- print("Use DreamBooth method.")
172
- train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
173
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
174
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
175
- args.bucket_reso_steps, args.bucket_no_upscale,
176
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
177
- args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
178
  else:
179
- print("Train with captions.")
180
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
181
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
182
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
183
- args.bucket_reso_steps, args.bucket_no_upscale,
184
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
185
- args.dataset_repeats, args.debug_dataset)
186
-
187
- # 学習データのdropout率を設定する
188
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
189
-
190
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  if args.debug_dataset:
193
- train_util.debug_dataset(train_dataset)
194
  return
195
- if len(train_dataset) == 0:
196
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
197
  return
198
 
 
 
 
 
199
  # acceleratorを準備する
200
  print("prepare accelerator")
201
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -205,9 +130,12 @@ def train(args):
205
 
206
  # モデルを読み込む
207
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
208
- # unnecessary, but work on low-ram device
209
- text_encoder.to("cuda")
210
- unet.to("cuda")
 
 
 
211
  # モデルに xformers とか memory efficient attention を組み込む
212
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
213
 
@@ -217,7 +145,7 @@ def train(args):
217
  vae.requires_grad_(False)
218
  vae.eval()
219
  with torch.no_grad():
220
- train_dataset.cache_latents(vae)
221
  vae.to("cpu")
222
  if torch.cuda.is_available():
223
  torch.cuda.empty_cache()
@@ -253,165 +181,45 @@ def train(args):
253
 
254
  # 学習に必要なクラスを準備する
255
  print("prepare optimizer, data loader etc.")
256
- try:
257
- print(f"torch_optimzier version is {optim.__version__}")
258
- not_torch_optimizer_flag = False
259
- except:
260
- not_torch_optimizer_flag = True
261
- try:
262
- print(f"adastand version is {adastand.__version__()}")
263
- not_adasatand_optimzier_flag = False
264
- except:
265
- not_adasatand_optimzier_flag = True
266
-
267
- # 8-bit Adamを使う
268
- if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
269
- not_torch_optimizer_flag = False
270
- if args.optimizer=="Adafactor":
271
- not_adasatand_optimzier_flag = False
272
- if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
273
- print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
274
- args.optimizer="AdamW"
275
- if args.use_8bit_adam:
276
- if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
277
- print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
278
- args.use_8bit_adam=False
279
- if args.use_8bit_adam:
280
- try:
281
- import bitsandbytes as bnb
282
- except ImportError:
283
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
284
- print("use 8-bit Adam optimizer")
285
- args.training_comment=f"{args.training_comment} use_8bit_adam=True"
286
- if args.optimizer=="Lamb":
287
- optimizer_class = bnb.optim.LAMB8bit
288
- else:
289
- args.optimizer="AdamW"
290
- optimizer_class = bnb.optim.AdamW8bit
291
- else:
292
- print(f"use {args.optimizer}")
293
- if args.optimizer=="RAdam":
294
- optimizer_class = torch.optim.RAdam
295
- elif args.optimizer=="AdaBound":
296
- optimizer_class = optim.AdaBound
297
- elif args.optimizer=="AdaBelief":
298
- optimizer_class = optim.AdaBelief
299
- elif args.optimizer=="AdamP":
300
- optimizer_class = optim.AdamP
301
- elif args.optimizer=="Adafactor":
302
- optimizer_class = Adafactor
303
- elif args.optimizer=="Adastand":
304
- optimizer_class = adastand.Adastand
305
- elif args.optimizer=="Adastand_belief":
306
- optimizer_class = adastand.Adastand_b
307
- elif args.optimizer=="AggMo":
308
- optimizer_class = optim.AggMo
309
- elif args.optimizer=="Apollo":
310
- optimizer_class = optim.Apollo
311
- elif args.optimizer=="Lamb":
312
- optimizer_class = optim.Lamb
313
- elif args.optimizer=="Ranger":
314
- optimizer_class = optim.Ranger
315
- elif args.optimizer=="RangerVA":
316
- optimizer_class = optim.RangerVA
317
- elif args.optimizer=="Yogi":
318
- optimizer_class = optim.Yogi
319
- elif args.optimizer=="Shampoo":
320
- optimizer_class = optim.Shampoo
321
- elif args.optimizer=="NovoGrad":
322
- optimizer_class = optim.NovoGrad
323
- elif args.optimizer=="QHAdam":
324
- optimizer_class = optim.QHAdam
325
- elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
326
- optimizer_class = optim.DiffGrad
327
- elif args.optimizer=="MADGRAD":
328
- optimizer_class = optim.MADGRAD
329
- else:
330
- optimizer_class = torch.optim.AdamW
331
-
332
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
333
- #optimizerデフォ設定
334
- if args.optimizer_arg==None:
335
- if args.optimizer=="AdaBelief":
336
- args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
337
- elif args.optimizer=="DiffGrad":
338
- args.optimizer_arg = ["eps=1e-16"]
339
- optimizer_arg = {}
340
- lookahed_arg = {"k": 5, "alpha": 0.5}
341
- adafactor_scheduler_arg = {"initial_lr": 0.}
342
- int_args = ["k","n_sma_threshold","warmup"]
343
- str_args = ["transformer","grad_transformer"]
344
- if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
345
- for _opt_arg in args.optimizer_arg:
346
- key, value = _opt_arg.split("=")
347
- if value=="True" or value=="False":
348
- optimizer_arg[key]=bool((value=="True"))
349
- elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
350
- _value = value.split(",")
351
- optimizer_arg[key] = (float(_value[0]),float(_value[1]))
352
- del _value
353
- elif key in int_args:
354
- if "Lookahead" in args.optimizer:
355
- lookahed_arg[key] = int(value)
356
- else:
357
- optimizer_arg[key] = int(value)
358
- elif key in str_args:
359
- optimizer_arg[key] = value
360
- else:
361
- if key=="alpha" and "Lookahead" in args.optimizer:
362
- lookahed_arg[key] = int(value)
363
- elif key=="initial_lr" and args.optimizer == "Adafactor":
364
- adafactor_scheduler_arg[key] = float(value)
365
- else:
366
- optimizer_arg[key] = float(value)
367
- del _opt_arg
368
- AdafactorScheduler_Flag = False
369
- list_of_init_lr = []
370
- if args.optimizer=="Adafactor":
371
- if not "relative_step" in optimizer_arg:
372
- optimizer_arg["relative_step"] = True
373
- if "warmup_init" in optimizer_arg:
374
- if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
375
- print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
376
- optimizer_arg["relative_step"] = True
377
- if optimizer_arg["relative_step"] == True:
378
- AdafactorScheduler_Flag = True
379
- list_of_init_lr = [0.,0.]
380
- if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
381
- if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
382
- #if not "initial_lr" in adafactor_scheduler_arg:
383
- # adafactor_scheduler_arg = args.learning_rate
384
- args.learning_rate = None
385
- args.text_encoder_lr = None
386
- args.unet_lr = None
387
- print(f"optimizer arg: {optimizer_arg}")
388
- print("=-----------------------------------=")
389
- if not AdafactorScheduler_Flag: args.split_lora_networks = False
390
  if args.split_lora_networks:
 
391
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
392
  append_module.replace_prepare_optimizer_params(network)
393
- trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
394
  else:
395
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
396
- _list_of_init_lr = []
397
- print(f"trainable_params_len: {len(trainable_params)}")
398
- if len(_list_of_init_lr)>0:
399
- list_of_init_lr = _list_of_init_lr
400
- print(f"split loras network is {len(list_of_init_lr)}")
401
- if len(list_of_init_lr) > 0:
402
- adafactor_scheduler_arg["initial_lr"] = list_of_init_lr
403
-
404
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)
405
-
406
- if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
407
- optimizer = optim.Lookahead(optimizer, **lookahed_arg)
408
- print(f"lookahed_arg: {lookahed_arg}")
409
-
 
 
 
 
 
 
 
 
 
 
410
  # dataloaderを準備する
411
  # DataLoaderのプロセス数:0はメインプロセスになる
412
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
413
  train_dataloader = torch.utils.data.DataLoader(
414
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
415
 
416
  # 学習ステップ数を計算する
417
  if args.max_train_epochs is not None:
@@ -419,22 +227,18 @@ def train(args):
419
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
420
 
421
  # lr schedulerを用意する
422
- # lr_scheduler = diffusers.optimization.get_scheduler(
423
- if AdafactorScheduler_Flag:
424
- print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
425
- lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
426
- print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
427
- del list_of_init_lr
428
  else:
429
- lr_scheduler = get_scheduler_fix(
430
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
431
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
432
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
433
-
434
  #追加機能の設定をコメントに追記して残す
435
- args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
436
- if AdafactorScheduler_Flag:
437
- args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
 
438
  if args.min_resolution:
439
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
440
 
@@ -504,17 +308,19 @@ def train(args):
504
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
505
 
506
  # 学習する
 
507
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
508
  print("running training / 学習開始")
509
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
510
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
511
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
512
  print(f" num epochs / epoch数: {num_train_epochs}")
513
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
514
- print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
515
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
516
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
517
 
 
518
  metadata = {
519
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
520
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -522,12 +328,10 @@ def train(args):
522
  "ss_learning_rate": args.learning_rate,
523
  "ss_text_encoder_lr": args.text_encoder_lr,
524
  "ss_unet_lr": args.unet_lr,
525
- "ss_num_train_images": train_dataset.num_train_images, # includes repeating
526
- "ss_num_reg_images": train_dataset.num_reg_images,
527
  "ss_num_batches_per_epoch": len(train_dataloader),
528
  "ss_num_epochs": num_train_epochs,
529
- "ss_batch_size_per_device": args.train_batch_size,
530
- "ss_total_batch_size": total_batch_size,
531
  "ss_gradient_checkpointing": args.gradient_checkpointing,
532
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
533
  "ss_max_train_steps": args.max_train_steps,
@@ -539,28 +343,149 @@ def train(args):
539
  "ss_mixed_precision": args.mixed_precision,
540
  "ss_full_fp16": bool(args.full_fp16),
541
  "ss_v2": bool(args.v2),
542
- "ss_resolution": args.resolution,
543
  "ss_clip_skip": args.clip_skip,
544
  "ss_max_token_length": args.max_token_length,
545
- "ss_color_aug": bool(args.color_aug),
546
- "ss_flip_aug": bool(args.flip_aug),
547
- "ss_random_crop": bool(args.random_crop),
548
- "ss_shuffle_caption": bool(args.shuffle_caption),
549
  "ss_cache_latents": bool(args.cache_latents),
550
- "ss_enable_bucket": bool(train_dataset.enable_bucket),
551
- "ss_min_bucket_reso": train_dataset.min_bucket_reso,
552
- "ss_max_bucket_reso": train_dataset.max_bucket_reso,
553
  "ss_seed": args.seed,
554
- "ss_keep_tokens": args.keep_tokens,
555
  "ss_noise_offset": args.noise_offset,
556
- "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
557
- "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
558
- "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
559
- "ss_bucket_info": json.dumps(train_dataset.bucket_info),
560
  "ss_training_comment": args.training_comment, # will not be updated after training
561
- "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
 
 
 
 
 
 
 
562
  }
563
 
564
  # uncomment if another network is added
565
  # for key, value in net_kwargs.items():
566
  # metadata["ss_arg_" + key] = value
@@ -596,7 +521,7 @@ def train(args):
596
  loss_total = 0.0
597
  for epoch in range(num_train_epochs):
598
  print(f"epoch {epoch+1}/{num_train_epochs}")
599
- train_dataset.set_current_epoch(epoch + 1)
600
 
601
  metadata["ss_epoch"] = str(epoch+1)
602
 
@@ -633,7 +558,7 @@ def train(args):
633
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
634
 
635
  # Predict the noise residual
636
- with autocast():
637
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
638
 
639
  if args.v_parameterization:
@@ -651,12 +576,18 @@ def train(args):
651
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
652
 
653
  accelerator.backward(loss)
654
- if accelerator.sync_gradients:
655
  params_to_clip = network.get_trainable_params()
656
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
657
 
 
658
  optimizer.step()
659
- lr_scheduler.step()
 
 
 
 
 
660
  optimizer.zero_grad(set_to_none=True)
661
 
662
  # Checks if the accelerator has performed an optimization step behind the scenes
@@ -664,6 +595,8 @@ def train(args):
664
  progress_bar.update(1)
665
  global_step += 1
666
 
 
 
667
  current_loss = loss.detach().item()
668
  if epoch == 0:
669
  loss_list.append(current_loss)
@@ -676,7 +609,7 @@ def train(args):
676
  progress_bar.set_postfix(**logs)
677
 
678
  if args.logging_dir is not None:
679
- logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
680
  accelerator.log(logs, step=global_step)
681
 
682
  if global_step >= args.max_train_steps:
@@ -694,6 +627,7 @@ def train(args):
694
  def save_func():
695
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
696
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
697
  print(f"saving checkpoint: {ckpt_file}")
698
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
699
 
@@ -708,9 +642,12 @@ def train(args):
708
  if saving and args.save_state:
709
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
710
 
 
 
711
  # end of epoch
712
 
713
  metadata["ss_epoch"] = str(num_train_epochs)
 
714
 
715
  is_main_process = accelerator.is_main_process
716
  if is_main_process:
@@ -741,6 +678,8 @@ if __name__ == '__main__':
741
  train_util.add_sd_models_arguments(parser)
742
  train_util.add_dataset_arguments(parser, True, True, True)
743
  train_util.add_training_arguments(parser, True)
 
 
744
 
745
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
746
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -748,10 +687,6 @@ if __name__ == '__main__':
748
 
749
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
750
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
751
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
752
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
753
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
754
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
755
 
756
  parser.add_argument("--network_weights", type=str, default=None,
757
  help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -771,27 +706,29 @@ if __name__ == '__main__':
771
  #Optimizer変更関連のオプション追加
772
  append_module.add_append_arguments(parser)
773
  args = append_module.get_config(parser)
 
 
 
 
 
 
 
 
 
 
 
774
 
775
  if args.resolution==args.min_resolution:
776
  args.min_resolution=None
777
 
778
  train(args)
 
779
 
780
- #学習が終わったら現在のargsを保存する
781
- # import yaml
782
- # import datetime
783
- # _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
784
- # if args.output_name==None:
785
- # config_name = f"train_network_config_{_t}.yaml"
786
- # else:
787
- # config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
788
- # print(f"{config_name} に設定を書き出し中...")
789
- # with open(config_name, mode="w") as f:
790
- # yaml.dump(args.__dict__, f, indent=4)
791
- # print("done!")
792
 
793
  '''
794
  optimizer設定メモ
 
 
795
  (optimizer_argから設定できるように変更するためのメモ)
796
 
797
  AdamWのweight_decay初期値は1e-2
@@ -821,6 +758,7 @@ Adafactor
821
  transformerベースのT5学習において最強とかいう噂のoptimizer
822
  huggingfaceのサンプルパラ
823
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
 
824
 
825
  AggMo
826
 
 
 
 
1
  from torch.cuda.amp import autocast
2
  from torch.nn.parallel import DistributedDataParallel as DDP
 
3
  import importlib
4
  import argparse
5
  import gc
 
14
  from accelerate.utils import set_seed
15
  import diffusers
16
  from diffusers import DDPMScheduler
17
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  ##### バケット拡張のためのモジュール
19
  import append_module
20
  ######
21
  import library.train_util as train_util
22
+ from library.train_util import (
23
+ DreamBoothDataset,
24
+ )
25
+ import library.config_util as config_util
26
+ from library.config_util import (
27
+ ConfigSanitizer,
28
+ BlueprintGenerator,
29
+ )
30
 
31
 
32
  def collate_fn(examples):
33
  return examples[0]
34
 
35
 
36
+ # TODO 他のスクリプトと共通化する
37
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
38
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
39
+ if not args.split_lora_networks:
40
+ if args.network_train_unet_only:
41
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
42
+ elif args.network_train_text_encoder_only:
43
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
44
+ else:
45
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
46
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
47
  else:
48
  last_lrs = lr_scheduler.get_last_lr()
49
+ for last_lr, t_name in zip(last_lrs, split_names):
50
+ logs[f"lr/{t_name}"] = float(last_lr)
51
+ #D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
52
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
53
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  return logs
56
 
57
 
58
  def train(args):
59
  session_id = random.randint(0, 2**32)
60
  training_started_at = time.time()
 
63
 
64
  cache_latents = args.cache_latents
65
  use_dreambooth_method = args.in_json is None
66
+ use_user_config = args.dataset_config is not None
67
 
68
  if args.seed is not None:
69
  set_seed(args.seed)
 
71
  tokenizer = train_util.load_tokenizer(args)
72
 
73
  # データセットを準備する
74
+ if args.min_resolution:
75
+ args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
76
+ if len(args.min_resolution) == 1:
77
+ args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
78
+ blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
 
 
 
 
 
 
 
 
79
  else:
80
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
81
+ if use_user_config:
82
+ print(f"Load dataset config from {args.dataset_config}")
83
+ user_config = config_util.load_user_config(args.dataset_config)
84
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
85
+ if any(getattr(args, attr) is not None for attr in ignored):
86
+ print(
87
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
88
+ else:
89
+ if use_dreambooth_method:
90
+ print("Use DreamBooth method.")
91
+ user_config = {
92
+ "datasets": [{
93
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
94
+ }]
95
+ }
96
+ else:
97
+ print("Train with captions.")
98
+ user_config = {
99
+ "datasets": [{
100
+ "subsets": [{
101
+ "image_dir": args.train_data_dir,
102
+ "metadata_file": args.in_json,
103
+ }]
104
+ }]
105
+ }
106
+
107
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
108
+ if args.min_resolution:
109
+ train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
110
+ else:
111
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
112
 
113
  if args.debug_dataset:
114
+ train_util.debug_dataset(train_dataset_group)
115
  return
116
+ if len(train_dataset_group) == 0:
117
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
118
  return
119
 
120
+ if cache_latents:
121
+ assert train_dataset_group.is_latent_cacheable(
122
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
123
+
124
  # acceleratorを準備する
125
  print("prepare accelerator")
126
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
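When no --dataset_config file is given, the branch above assembles an in-memory user_config with the same shape a config file would have: a list of datasets, each holding a list of subsets. A minimal sketch of such a structure for a DreamBooth-style run; the paths, repeat counts, and the dataset-level resolution key are illustrative assumptions, while image_dir / class_tokens / num_repeats / is_reg mirror subset attributes used elsewhere in this diff.

    # hypothetical user_config equivalent to what the DreamBooth branch builds above
    user_config = {
        "datasets": [{
            "resolution": 512,  # assumption: dataset-level option accepted by the sanitizer
            "subsets": [
                {   # training images
                    "image_dir": "/path/to/train/10_mychar girl",
                    "class_tokens": "mychar girl",
                    "num_repeats": 10,
                    "is_reg": False,
                },
                {   # regularization images
                    "image_dir": "/path/to/reg/1_girl",
                    "class_tokens": "girl",
                    "num_repeats": 1,
                    "is_reg": True,
                },
            ],
        }]
    }
    # this dict then flows through the same path as a loaded config:
    # blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)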
 
130
 
131
  # モデルを読み込む
132
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
133
+
134
+ # work on low-ram device
135
+ if args.lowram:
136
+ text_encoder.to("cuda")
137
+ unet.to("cuda")
138
+
139
  # モデルに xformers とか memory efficient attention を組み込む
140
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
141
 
 
145
  vae.requires_grad_(False)
146
  vae.eval()
147
  with torch.no_grad():
148
+ train_dataset_group.cache_latents(vae)
149
  vae.to("cpu")
150
  if torch.cuda.is_available():
151
  torch.cuda.empty_cache()
 
181
 
182
  # 学習に必要なクラスを準備する
183
  print("prepare optimizer, data loader etc.")
184
+ split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
185
+
186
+ used_names = None
 
187
  if args.split_lora_networks:
188
+ lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
189
  lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
190
  append_module.replace_prepare_optimizer_params(network)
191
+ trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
192
  else:
193
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
194
+ if split_flag:
195
+ _t_lr = 0.
196
+ _u_lr = 0.
197
+ if args.text_encoder_lr:
198
+ _t_lr = args.text_encoder_lr
199
+ if args.unet_lr:
200
+ _u_lr = args.unet_lr
201
+ adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
202
+
203
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
204
+ if args.use_lookahead:
205
+ try:
206
+ import torch_optimizer
207
+ lookahed_arg = {"k": 5, "alpha": 0.5}
208
+ if args.lookahead_arg is not None:
209
+ for _arg in args.lookahead_arg:
210
+ k, v = _arg.split("=")
211
+ if k == "k":
212
+ lookahed_arg[k] = int(v)
213
+ else:
214
+ lookahed_arg[k] = float(v)
215
+ optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
216
+ except:
217
+ print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
218
  # dataloaderを準備する
219
  # DataLoaderのプロセス数:0はメインプロセスになる
220
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
221
  train_dataloader = torch.utils.data.DataLoader(
222
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
223
 
224
  # 学習ステップ数を計算する
225
  if args.max_train_epochs is not None:
 
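The optional Lookahead wrapping above comes from the external torch_optimizer package and simply wraps whatever base optimizer train_util.get_optimizer returned. A minimal sketch with a toy model, assuming torch_optimizer is installed; k and alpha mirror the defaults used above.

    import torch
    import torch_optimizer

    model = torch.nn.Linear(8, 1)  # stand-in for the trainable network parameters
    base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Lookahead keeps a set of slow weights and syncs them with the fast weights every k steps
    optimizer = torch_optimizer.Lookahead(base_optimizer, k=5, alpha=0.5)

    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()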
227
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
228
 
229
  # lr schedulerを用意する
230
+ if args.lr_scheduler.startswith("adafactor") and split_flag:
231
+ lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
 
 
 
 
232
  else:
233
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
234
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
235
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
236
+
 
237
  #追加機能の設定をコメントに追記して残す
238
+ if args.use_lookahead:
239
+ args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
240
+ if args.split_lora_networks:
241
+ args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
242
  if args.min_resolution:
243
  args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
244
 
 
308
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
309
 
310
  # 学習する
311
+ # TODO: find a way to handle total batch size when there are multiple datasets
312
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
313
  print("running training / 学習開始")
314
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
315
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
316
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
317
  print(f" num epochs / epoch数: {num_train_epochs}")
318
+ print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
319
+ # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
320
  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
321
  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
322
 
323
+ # TODO refactor metadata creation and move to util
324
  metadata = {
325
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
326
  "ss_training_started_at": training_started_at, # unix timestamp
 
328
  "ss_learning_rate": args.learning_rate,
329
  "ss_text_encoder_lr": args.text_encoder_lr,
330
  "ss_unet_lr": args.unet_lr,
331
+ "ss_num_train_images": train_dataset_group.num_train_images,
332
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
333
  "ss_num_batches_per_epoch": len(train_dataloader),
334
  "ss_num_epochs": num_train_epochs,
 
 
335
  "ss_gradient_checkpointing": args.gradient_checkpointing,
336
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
337
  "ss_max_train_steps": args.max_train_steps,
 
343
  "ss_mixed_precision": args.mixed_precision,
344
  "ss_full_fp16": bool(args.full_fp16),
345
  "ss_v2": bool(args.v2),
 
346
  "ss_clip_skip": args.clip_skip,
347
  "ss_max_token_length": args.max_token_length,
 
 
 
 
348
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
349
  "ss_seed": args.seed,
350
+ "ss_lowram": args.lowram,
351
  "ss_noise_offset": args.noise_offset,
 
 
 
 
352
  "ss_training_comment": args.training_comment, # will not be updated after training
353
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
354
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
355
+ "ss_max_grad_norm": args.max_grad_norm,
356
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
357
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
358
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
359
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
360
+ "ss_prior_loss_weight": args.prior_loss_weight,
361
  }
362
 
363
+ if use_user_config:
364
+ # save metadata of multiple datasets
365
+ # NOTE: pack "ss_datasets" value as json one time
366
+ # or should also pack nested collections as json?
367
+ datasets_metadata = []
368
+ tag_frequency = {} # merge tag frequency for metadata editor
369
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
370
+
371
+ for dataset in train_dataset_group.datasets:
372
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
373
+ dataset_metadata = {
374
+ "is_dreambooth": is_dreambooth_dataset,
375
+ "batch_size_per_device": dataset.batch_size,
376
+ "num_train_images": dataset.num_train_images, # includes repeating
377
+ "num_reg_images": dataset.num_reg_images,
378
+ "resolution": (dataset.width, dataset.height),
379
+ "enable_bucket": bool(dataset.enable_bucket),
380
+ "min_bucket_reso": dataset.min_bucket_reso,
381
+ "max_bucket_reso": dataset.max_bucket_reso,
382
+ "tag_frequency": dataset.tag_frequency,
383
+ "bucket_info": dataset.bucket_info,
384
+ }
385
+
386
+ subsets_metadata = []
387
+ for subset in dataset.subsets:
388
+ subset_metadata = {
389
+ "img_count": subset.img_count,
390
+ "num_repeats": subset.num_repeats,
391
+ "color_aug": bool(subset.color_aug),
392
+ "flip_aug": bool(subset.flip_aug),
393
+ "random_crop": bool(subset.random_crop),
394
+ "shuffle_caption": bool(subset.shuffle_caption),
395
+ "keep_tokens": subset.keep_tokens,
396
+ }
397
+
398
+ image_dir_or_metadata_file = None
399
+ if subset.image_dir:
400
+ image_dir = os.path.basename(subset.image_dir)
401
+ subset_metadata["image_dir"] = image_dir
402
+ image_dir_or_metadata_file = image_dir
403
+
404
+ if is_dreambooth_dataset:
405
+ subset_metadata["class_tokens"] = subset.class_tokens
406
+ subset_metadata["is_reg"] = subset.is_reg
407
+ if subset.is_reg:
408
+ image_dir_or_metadata_file = None # not merging reg dataset
409
+ else:
410
+ metadata_file = os.path.basename(subset.metadata_file)
411
+ subset_metadata["metadata_file"] = metadata_file
412
+ image_dir_or_metadata_file = metadata_file # may overwrite
413
+
414
+ subsets_metadata.append(subset_metadata)
415
+
416
+ # merge dataset dir: not reg subset only
417
+ # TODO update additional-network extension to show detailed dataset config from metadata
418
+ if image_dir_or_metadata_file is not None:
419
+ # datasets may have a certain dir multiple times
420
+ v = image_dir_or_metadata_file
421
+ i = 2
422
+ while v in dataset_dirs_info:
423
+ v = image_dir_or_metadata_file + f" ({i})"
424
+ i += 1
425
+ image_dir_or_metadata_file = v
426
+
427
+ dataset_dirs_info[image_dir_or_metadata_file] = {
428
+ "n_repeats": subset.num_repeats,
429
+ "img_count": subset.img_count
430
+ }
431
+
432
+ dataset_metadata["subsets"] = subsets_metadata
433
+ datasets_metadata.append(dataset_metadata)
434
+
435
+ # merge tag frequency:
436
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
437
+ # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
438
+ # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
439
+ # なので、ここで複数datasetの回数を合算してもあまり意味はない
440
+ if ds_dir_name in tag_frequency:
441
+ continue
442
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
443
+
444
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
445
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
446
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
447
+ else:
448
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
449
+ assert len(
450
+ train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
451
+
452
+ dataset = train_dataset_group.datasets[0]
453
+
454
+ dataset_dirs_info = {}
455
+ reg_dataset_dirs_info = {}
456
+ if use_dreambooth_method:
457
+ for subset in dataset.subsets:
458
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
459
+ info[os.path.basename(subset.image_dir)] = {
460
+ "n_repeats": subset.num_repeats,
461
+ "img_count": subset.img_count
462
+ }
463
+ else:
464
+ for subset in dataset.subsets:
465
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
466
+ "n_repeats": subset.num_repeats,
467
+ "img_count": subset.img_count
468
+ }
469
+
470
+ metadata.update({
471
+ "ss_batch_size_per_device": args.train_batch_size,
472
+ "ss_total_batch_size": total_batch_size,
473
+ "ss_resolution": args.resolution,
474
+ "ss_color_aug": bool(args.color_aug),
475
+ "ss_flip_aug": bool(args.flip_aug),
476
+ "ss_random_crop": bool(args.random_crop),
477
+ "ss_shuffle_caption": bool(args.shuffle_caption),
478
+ "ss_enable_bucket": bool(dataset.enable_bucket),
479
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
480
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
481
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
482
+ "ss_keep_tokens": args.keep_tokens,
483
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
484
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
485
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
486
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
487
+ })
488
+
489
  # uncomment if another network is added
490
  # for key, value in net_kwargs.items():
491
  # metadata["ss_arg_" + key] = value
 
521
  loss_total = 0.0
522
  for epoch in range(num_train_epochs):
523
  print(f"epoch {epoch+1}/{num_train_epochs}")
524
+ train_dataset_group.set_current_epoch(epoch + 1)
525
 
526
  metadata["ss_epoch"] = str(epoch+1)
527
 
 
558
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
559
 
560
  # Predict the noise residual
561
+ with accelerator.autocast():
562
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
563
 
564
  if args.v_parameterization:
 
576
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
577
 
578
  accelerator.backward(loss)
579
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
580
  params_to_clip = network.get_trainable_params()
581
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
582
 
583
+ scale = accelerator.scaler.get_scale()
584
  optimizer.step()
585
+ if args.lr_scheduler.startswith("adafactor"):
586
+ skip_lr_sched = (scale >= accelerator.scaler.get_scale())
587
+ else:
588
+ skip_lr_sched = True
589
+ if not skip_lr_sched:
590
+ lr_scheduler.step()
591
  optimizer.zero_grad(set_to_none=True)
592
 
593
  # Checks if the accelerator has performed an optimization step behind the scenes
 
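The scale comparison above is there to keep the LR scheduler in sync with mixed-precision training: GradScaler lowers its scale after an inf/NaN gradient and silently skips the optimizer step, so a drop in scale marks a step that should not advance the scheduler. A minimal standalone sketch of that general pattern with plain torch.cuda.amp rather than the accelerator wrapper used above; it assumes a CUDA device is available.

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(4):
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
        scale_before = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)   # skipped internally if inf/NaN gradients were found
        scaler.update()          # lowers the scale when the step was skipped
        optimizer.zero_grad(set_to_none=True)
        if scaler.get_scale() >= scale_before:  # step was applied, keep counts aligned
            scheduler.step()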
595
  progress_bar.update(1)
596
  global_step += 1
597
 
598
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
599
+
600
  current_loss = loss.detach().item()
601
  if epoch == 0:
602
  loss_list.append(current_loss)
 
609
  progress_bar.set_postfix(**logs)
610
 
611
  if args.logging_dir is not None:
612
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
613
  accelerator.log(logs, step=global_step)
614
 
615
  if global_step >= args.max_train_steps:
 
627
  def save_func():
628
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
629
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
630
+ metadata["ss_training_finished_at"] = str(time.time())
631
  print(f"saving checkpoint: {ckpt_file}")
632
  unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
633
 
 
642
  if saving and args.save_state:
643
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
644
 
645
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
646
+
647
  # end of epoch
648
 
649
  metadata["ss_epoch"] = str(num_train_epochs)
650
+ metadata["ss_training_finished_at"] = str(time.time())
651
 
652
  is_main_process = accelerator.is_main_process
653
  if is_main_process:
 
678
  train_util.add_sd_models_arguments(parser)
679
  train_util.add_dataset_arguments(parser, True, True, True)
680
  train_util.add_training_arguments(parser, True)
681
+ train_util.add_optimizer_arguments(parser)
682
+ config_util.add_config_arguments(parser)
683
 
684
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
685
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
687
 
688
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
689
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
690
 
691
  parser.add_argument("--network_weights", type=str, default=None,
692
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
706
  #Optimizer変更関連のオプション追加
707
  append_module.add_append_arguments(parser)
708
  args = append_module.get_config(parser)
709
+ #argsを保存する
710
+ import yaml
711
+ import datetime
712
+ _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
713
+ if args.output_name==None:
714
+ config_name = f"train_network_config_{_t}.yaml"
715
+ else:
716
+ config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
717
+ print(f"{config_name} に設定を書き出し中...")
718
+ with open(config_name, mode="w") as f:
719
+ yaml.dump(args.__dict__, f, indent=4)
720
 
721
  if args.resolution==args.min_resolution:
722
  args.min_resolution=None
723
 
724
  train(args)
725
+ print("done!")
726
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
  '''
729
  optimizer設定メモ
730
+ torch_optimizer.AdaBelief
731
+ adastand.Adastand
732
  (optimizer_argから設定できるように変更するためのメモ)
733
 
734
  AdamWのweight_decay初期値は1e-2
 
758
  transformerベースのT5学習において最強とかいう噂のoptimizer
759
  huggingfaceのサンプルパラ
760
  eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
761
+ epsの二つ目の値1e-3が学習率に影響大きい
762
 
763
  AggMo
764
 
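A minimal sketch of the two Adafactor setups described in the memo above, built directly with transformers.optimization; the toy model is a placeholder. The first call mirrors the quoted Hugging Face sample parameters (external LR; the memo notes the second eps value, 1e-3, has a large effect on the effective learning rate), the second uses relative_step so the LR is derived internally and AdafactorSchedule only reports it.

    import torch
    from transformers.optimization import Adafactor, AdafactorSchedule

    model = torch.nn.Linear(8, 1)  # placeholder parameters

    # external learning rate, matching the sample parameters quoted above
    optimizer = Adafactor(
        model.parameters(), lr=1e-3,
        eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8,
        relative_step=False, scale_parameter=False, warmup_init=False,
    )

    # relative_step variant: the LR is computed internally from the step count
    optimizer_rel = Adafactor(model.parameters(), lr=None,
                              relative_step=True, scale_parameter=True, warmup_init=True)
    lr_scheduler = AdafactorSchedule(optimizer_rel, initial_lr=0.0)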
train_textual_inversion.py CHANGED
@@ -11,7 +11,11 @@ import diffusers
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
15
 
16
  imagenet_templates_small = [
17
  "a photo of a {}",
@@ -79,7 +83,6 @@ def train(args):
79
  train_util.prepare_dataset_args(args, True)
80
 
81
  cache_latents = args.cache_latents
82
- use_dreambooth_method = args.in_json is None
83
 
84
  if args.seed is not None:
85
  set_seed(args.seed)
@@ -139,21 +142,35 @@ def train(args):
139
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
140
 
141
  # データセットを準備する
142
- if use_dreambooth_method:
143
- print("Use DreamBooth method.")
144
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
145
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
146
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
147
- args.bucket_reso_steps, args.bucket_no_upscale,
148
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
149
  else:
150
- print("Train with captions.")
151
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
152
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
153
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
154
- args.bucket_reso_steps, args.bucket_no_upscale,
155
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
156
- args.dataset_repeats, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
159
  if use_template:
@@ -163,20 +180,25 @@ def train(args):
163
  captions = []
164
  for tmpl in templates:
165
  captions.append(tmpl.format(replace_to))
166
- train_dataset.add_replacement("", captions)
167
- elif args.num_vectors_per_token > 1:
168
- replace_to = " ".join(token_strings)
169
- train_dataset.add_replacement(args.token_string, replace_to)
170
-
171
- train_dataset.make_buckets()
 
 
172
 
173
  if args.debug_dataset:
174
- train_util.debug_dataset(train_dataset, show_input_ids=True)
175
  return
176
- if len(train_dataset) == 0:
177
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
178
  return
179
 
 
 
 
180
  # モデルに xformers とか memory efficient attention を組み込む
181
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
182
 
@@ -186,7 +208,7 @@ def train(args):
186
  vae.requires_grad_(False)
187
  vae.eval()
188
  with torch.no_grad():
189
- train_dataset.cache_latents(vae)
190
  vae.to("cpu")
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
@@ -198,35 +220,14 @@ def train(args):
198
 
199
  # 学習に必要なクラスを準備する
200
  print("prepare optimizer, data loader etc.")
201
-
202
- # 8-bit Adamを使う
203
- if args.use_8bit_adam:
204
- try:
205
- import bitsandbytes as bnb
206
- except ImportError:
207
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
208
- print("use 8-bit Adam optimizer")
209
- optimizer_class = bnb.optim.AdamW8bit
210
- elif args.use_lion_optimizer:
211
- try:
212
- import lion_pytorch
213
- except ImportError:
214
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
215
- print("use Lion optimizer")
216
- optimizer_class = lion_pytorch.Lion
217
- else:
218
- optimizer_class = torch.optim.AdamW
219
-
220
  trainable_params = text_encoder.get_input_embeddings().parameters()
221
-
222
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
223
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
224
 
225
  # dataloaderを準備する
226
  # DataLoaderのプロセス数:0はメインプロセスになる
227
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
228
  train_dataloader = torch.utils.data.DataLoader(
229
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
230
 
231
  # 学習ステップ数を計算する
232
  if args.max_train_epochs is not None:
@@ -234,8 +235,9 @@ def train(args):
234
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
235
 
236
  # lr schedulerを用意する
237
- lr_scheduler = diffusers.optimization.get_scheduler(
238
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
239
 
240
  # acceleratorがなんかよろしくやってくれるらしい
241
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -283,8 +285,8 @@ def train(args):
283
  # 学習する
284
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
285
  print("running training / 学習開始")
286
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
287
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
288
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
289
  print(f" num epochs / epoch数: {num_train_epochs}")
290
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -303,12 +305,11 @@ def train(args):
303
 
304
  for epoch in range(num_train_epochs):
305
  print(f"epoch {epoch+1}/{num_train_epochs}")
306
- train_dataset.set_current_epoch(epoch + 1)
307
 
308
  text_encoder.train()
309
 
310
  loss_total = 0
311
- bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
312
  for step, batch in enumerate(train_dataloader):
313
  with accelerator.accumulate(text_encoder):
314
  with torch.no_grad():
@@ -357,9 +358,9 @@ def train(args):
357
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
358
 
359
  accelerator.backward(loss)
360
- if accelerator.sync_gradients:
361
  params_to_clip = text_encoder.get_input_embeddings().parameters()
362
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
363
 
364
  optimizer.step()
365
  lr_scheduler.step()
@@ -374,9 +375,14 @@ def train(args):
374
  progress_bar.update(1)
375
  global_step += 1
376
 
 
 
 
377
  current_loss = loss.detach().item()
378
  if args.logging_dir is not None:
379
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
380
  accelerator.log(logs, step=global_step)
381
 
382
  loss_total += current_loss
@@ -394,8 +400,6 @@ def train(args):
394
  accelerator.wait_for_everyone()
395
 
396
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
397
- # d = updated_embs - bef_epo_embs
398
- # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
399
 
400
  if args.save_every_n_epochs is not None:
401
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -417,6 +421,9 @@ def train(args):
417
  if saving and args.save_state:
418
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
419
 
420
  # end of epoch
421
 
422
  is_main_process = accelerator.is_main_process
@@ -491,6 +498,8 @@ if __name__ == '__main__':
491
  train_util.add_sd_models_arguments(parser)
492
  train_util.add_dataset_arguments(parser, True, True, False)
493
  train_util.add_training_arguments(parser, True)
494
 
495
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
496
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
 
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
+ import library.config_util as config_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
 
20
  imagenet_templates_small = [
21
  "a photo of a {}",
 
83
  train_util.prepare_dataset_args(args, True)
84
 
85
  cache_latents = args.cache_latents
 
86
 
87
  if args.seed is not None:
88
  set_seed(args.seed)
 
142
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
143
 
144
  # prepare the dataset
145
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
146
+ if args.dataset_config is not None:
147
+ print(f"Load dataset config from {args.dataset_config}")
148
+ user_config = config_util.load_user_config(args.dataset_config)
149
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
150
+ if any(getattr(args, attr) is not None for attr in ignored):
151
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
152
  else:
153
+ use_dreambooth_method = args.in_json is None
154
+ if use_dreambooth_method:
155
+ print("Use DreamBooth method.")
156
+ user_config = {
157
+ "datasets": [{
158
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
159
+ }]
160
+ }
161
+ else:
162
+ print("Train with captions.")
163
+ user_config = {
164
+ "datasets": [{
165
+ "subsets": [{
166
+ "image_dir": args.train_data_dir,
167
+ "metadata_file": args.in_json,
168
+ }]
169
+ }]
170
+ }
171
+
172
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
173
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
174
 
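For reference: when no --dataset_config is given, the DreamBooth branch above derives one subset per training sub-directory, where the directory name encodes the repeat count and the class tokens (e.g. "10_sks dog"). A rough sketch of the resulting structure, with placeholder paths and illustrative keys (the exact key names come from library.config_util, not from this snippet):

# Illustrative only -- "train/10_sks dog" and "reg/1_dog" are placeholder directories.
user_config = {
    "datasets": [{
        "subsets": [
            {"image_dir": "train/10_sks dog", "num_repeats": 10, "class_tokens": "sks dog"},
            {"image_dir": "reg/1_dog", "num_repeats": 1, "class_tokens": "dog", "is_reg": True},
        ]
    }]
}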
175
  # make captions: very crude implementation that rewrites every caption to the string "tokenstring tokenstring1 tokenstring2 ... tokenstringn"
176
  if use_template:
 
180
  captions = []
181
  for tmpl in templates:
182
  captions.append(tmpl.format(replace_to))
183
+ train_dataset_group.add_replacement("", captions)
184
+ else:
185
+ if args.num_vectors_per_token > 1:
186
+ replace_to = " ".join(token_strings)
187
+ train_dataset_group.add_replacement(args.token_string, replace_to)
188
+ prompt_replacement = (args.token_string, replace_to)
189
+ else:
190
+ prompt_replacement = None
191
 
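For context, a minimal sketch of what the replacement above does for a multi-vector token: every occurrence of the trigger string in a caption (and later in sample prompts, via prompt_replacement) is expanded into the full run of sub-token strings. Values here are placeholders, not taken from this commit:

token_string = "mychar"                       # stands in for args.token_string
num_vectors_per_token = 4                     # stands in for args.num_vectors_per_token
token_strings = [token_string] + [f"{token_string}{i}" for i in range(1, num_vectors_per_token)]
replace_to = " ".join(token_strings)          # "mychar mychar1 mychar2 mychar3"

caption = "a photo of mychar at the beach"
print(caption.replace(token_string, replace_to))
# -> "a photo of mychar mychar1 mychar2 mychar3 at the beach"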
192
  if args.debug_dataset:
193
+ train_util.debug_dataset(train_dataset_group, show_input_ids=True)
194
  return
195
+ if len(train_dataset_group) == 0:
196
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
197
  return
198
 
199
+ if cache_latents:
200
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
201
+
202
  # integrate xformers or memory efficient attention into the model
203
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
204
 
 
208
  vae.requires_grad_(False)
209
  vae.eval()
210
  with torch.no_grad():
211
+ train_dataset_group.cache_latents(vae)
212
  vae.to("cpu")
213
  if torch.cuda.is_available():
214
  torch.cuda.empty_cache()
 
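The block above encodes every training image through the frozen VAE once, so the VAE can be dropped from VRAM for the rest of training. A rough, stand-alone sketch of the idea (not the actual cache_latents implementation; 0.18215 is the usual Stable Diffusion latent scaling factor):

import torch

@torch.no_grad()
def cache_latents_sketch(vae, images, scaling_factor=0.18215):
    # images: float tensors in [-1, 1], shape (3, H, W), already resized to the bucket size.
    # color_aug / random_crop must be disabled, otherwise the cached latent would no
    # longer match the augmented pixels seen at training time (hence the assert above).
    cached = []
    for image in images:
        latent = vae.encode(image.unsqueeze(0)).latent_dist.sample() * scaling_factor
        cached.append(latent.squeeze(0).to("cpu"))
    return cached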
220
 
221
  # prepare the classes needed for training
222
  print("prepare optimizer, data loader etc.")
223
  trainable_params = text_encoder.get_input_embeddings().parameters()
224
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
225
 
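The inline 8bit-Adam / Lion / AdamW selection that used to live here is now delegated to train_util.get_optimizer, which returns the optimizer as the last element of a tuple (hence the "_, _, optimizer = ..." unpacking). A simplified dispatcher of the same shape, for illustration only (names and defaults are not the library's exact behavior):

import torch

def get_optimizer_sketch(args, trainable_params):
    name = (getattr(args, "optimizer_type", "") or "AdamW").lower()
    if name == "lion":
        import lion_pytorch
        optimizer_class = lion_pytorch.Lion
    elif name == "adamw8bit":
        import bitsandbytes as bnb
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW
    optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
    # mirror the (name, arguments, optimizer) tuple shape used by the script
    return optimizer_class.__name__, {}, optimizer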
226
  # prepare the dataloader
227
  # number of DataLoader processes: 0 means the main process
228
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified value
229
  train_dataloader = torch.utils.data.DataLoader(
230
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
231
 
232
  # compute the number of training steps
233
  if args.max_train_epochs is not None:
 
235
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
236
 
237
  # prepare the lr scheduler
238
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
239
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
240
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
241
 
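get_scheduler_fix forwards num_cycles and power to the underlying warmup schedulers. As a rough, torch-only illustration of the warmup-then-cosine shape such a scheduler produces (this is not the library implementation):

import math
import torch

def warmup_cosine_lambda(num_warmup_steps, num_training_steps, num_cycles=0.5):
    # Linear warmup, then cosine decay; num_cycles controls how many half-waves
    # fit into the decay phase (0.5 = a single smooth decay to zero).
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress)))
    return lr_lambda

# usage (illustrative): torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine_lambda(100, 10000))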
242
  # accelerator apparently takes care of this for us
243
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 
285
  # start training
286
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
287
  print("running training / 学習開始")
288
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
289
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
290
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
291
  print(f" num epochs / epoch数: {num_train_epochs}")
292
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
305
 
306
  for epoch in range(num_train_epochs):
307
  print(f"epoch {epoch+1}/{num_train_epochs}")
308
+ train_dataset_group.set_current_epoch(epoch + 1)
309
 
310
  text_encoder.train()
311
 
312
  loss_total = 0
 
313
  for step, batch in enumerate(train_dataloader):
314
  with accelerator.accumulate(text_encoder):
315
  with torch.no_grad():
 
358
  loss = loss.mean()  # already averaged, so no need to divide by batch_size
359
 
360
  accelerator.backward(loss)
361
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
362
  params_to_clip = text_encoder.get_input_embeddings().parameters()
363
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
364
 
365
  optimizer.step()
366
  lr_scheduler.step()
 
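The clipping above is now gated on --max_grad_norm, with 0.0 turning clipping off entirely. Without accelerate, the same intent looks roughly like this self-contained snippet:

import torch

model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).pow(2).mean().backward()

max_grad_norm = 1.0                 # stands in for args.max_grad_norm; 0.0 means "do not clip"
if max_grad_norm != 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)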
375
  progress_bar.update(1)
376
  global_step += 1
377
 
378
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
379
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
380
+
381
  current_loss = loss.detach().item()
382
  if args.logging_dir is not None:
383
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
384
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
385
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
386
  accelerator.log(logs, step=global_step)
387
 
388
  loss_total += current_loss
 
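With D-Adaptation the configured learning rate is usually left at 1.0 and the optimizer adapts its own step size, stored as "d" in each param group; the logging above therefore reports d*lr as the effective learning rate. A small helper with the same reading, for illustration (assumes a dadaptation-style optimizer):

def dadaptation_effective_lr(optimizer):
    # D-Adaptation stores its adapted scale as "d" in each param group,
    # so the effective step size is d * lr (lr is typically configured as 1.0).
    group = optimizer.param_groups[0]
    return group["d"] * group["lr"]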
400
  accelerator.wait_for_everyone()
401
 
402
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
403
 
404
  if args.save_every_n_epochs is not None:
405
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
 
421
  if saving and args.save_state:
422
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
423
 
424
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
425
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
426
+
427
  # end of epoch
428
 
429
  is_main_process = accelerator.is_main_process
 
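prompt_replacement is threaded into sample_images above so that sample prompts written with the original trigger word are rewritten to the expanded multi-vector form before generation. The substitution itself is just a string replace, roughly (placeholder values, not from this commit):

prompt_replacement = ("mychar", "mychar mychar1 mychar2 mychar3")
prompt = "a portrait of mychar, best quality"
if prompt_replacement is not None:
    prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
print(prompt)   # -> "a portrait of mychar mychar1 mychar2 mychar3, best quality"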
498
  train_util.add_sd_models_arguments(parser)
499
  train_util.add_dataset_arguments(parser, True, True, False)
500
  train_util.add_training_arguments(parser, True)
501
+ train_util.add_optimizer_arguments(parser)
502
+ config_util.add_config_arguments(parser)
503
 
504
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
505
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")