abc committed
Commit 0d34cbd · 1 Parent(s): b9a80b5

Upload 44 files
.gitattributes ADDED
@@ -0,0 +1 @@
+bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
fine_tune.py CHANGED
@@ -13,7 +13,11 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
-
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)
 
 def collate_fn(examples):
   return examples[0]
@@ -30,25 +34,36 @@ def train(args):
 
   tokenizer = train_util.load_tokenizer(args)
 
-  train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
-                                               tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
-                                               args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
-                                               args.bucket_reso_steps, args.bucket_no_upscale,
-                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
-                                               args.dataset_repeats, args.debug_dataset)
-
-  # 学習データのdropout率を設定する
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
-
-  train_dataset.make_buckets()
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "in_json"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": [{
+          "image_dir": args.train_data_dir,
+          "metadata_file": args.in_json,
+        }]
+      }]
+    }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   if args.debug_dataset:
-    train_util.debug_dataset(train_dataset)
+    train_util.debug_dataset(train_dataset_group)
     return
-  if len(train_dataset) == 0:
+  if len(train_dataset_group) == 0:
     print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
     return
 
+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
   # acceleratorを準備する
   print("prepare accelerator")
   accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -109,7 +124,7 @@ def train(args):
   vae.requires_grad_(False)
   vae.eval()
   with torch.no_grad():
-    train_dataset.cache_latents(vae)
+    train_dataset_group.cache_latents(vae)
   vae.to("cpu")
   if torch.cuda.is_available():
     torch.cuda.empty_cache()
@@ -149,33 +164,13 @@ def train(args):
 
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
-  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-    train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+    train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -183,8 +178,9 @@ def train(args):
   print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -218,7 +214,7 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f"  num examples / サンプル数: {train_dataset.num_train_images}")
+  print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
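Note on the optimizer hunk above: the inline 8-bit Adam / Lion selection is gone and a single train_util.get_optimizer call now owns that decision (the "_, _, optimizer" unpacking suggests it also returns the optimizer name and arguments). A minimal sketch of the selection logic it replaces, based only on the deleted code — the helper name below is illustrative, and the real get_optimizer also covers the new --optimizer_type choices:

import torch

def select_optimizer_class(args):
  # Same precedence as the removed inline code: 8-bit Adam, then Lion, then plain AdamW.
  if args.use_8bit_adam:
    try:
      import bitsandbytes as bnb  # optional dependency
    except ImportError:
      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
    return bnb.optim.AdamW8bit
  if args.use_lion_optimizer:
    try:
      import lion_pytorch  # optional dependency
    except ImportError:
      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
    return lion_pytorch.Lion
  return torch.optim.AdamW

# optimizer = select_optimizer_class(args)(params_to_optimize, lr=args.learning_rate)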
 
@@ -237,7 +233,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_epoch(epoch + 1)
 
     for m in training_models:
       m.train()
@@ -286,11 +282,11 @@ def train(args):
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
        accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
          params_to_clip = []
          for m in training_models:
            params_to_clip.extend(m.parameters())
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
        optimizer.step()
        lr_scheduler.step()
@@ -301,11 +297,16 @@ def train(args):
        progress_bar.update(1)
        global_step += 1
 
+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
        current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
        if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+            logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
          accelerator.log(logs, step=global_step)
 
+        # TODO moving averageにする
        loss_total += current_loss
        avr_loss = loss_total / (step+1)
        logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -315,7 +316,7 @@ def train(args):
        break
 
    if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
      accelerator.log(logs, step=epoch+1)
 
    accelerator.wait_for_everyone()
@@ -325,6 +326,8 @@ def train(args):
      train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                            save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
 
+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
  is_main_process = accelerator.is_main_process
  if is_main_process:
    unet = unwrap_model(unet)
@@ -351,6 +354,8 @@ if __name__ == '__main__':
  train_util.add_dataset_arguments(parser, False, True, True)
  train_util.add_training_arguments(parser, False)
  train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)
 
  parser.add_argument("--diffusers_xformers", action='store_true',
                      help='use xformers by diffusers / Diffusersでxformersを使用する')
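The lr scheduler hunk above swaps diffusers.optimization.get_scheduler for train_util.get_scheduler_fix so that num_cycles and power actually reach the underlying schedules. A hedged sketch of what such a wrapper can look like using only public diffusers helpers — the real get_scheduler_fix may differ in its details:

from diffusers.optimization import (
    get_scheduler,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)

def get_scheduler_sketch(name, optimizer, num_warmup_steps, num_training_steps, num_cycles=1, power=1.0):
  # Forward the extra options that the generic get_scheduler() entry point does not accept.
  if name == "cosine_with_restarts":
    return get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles)
  if name == "polynomial":
    return get_polynomial_decay_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
  return get_scheduler(name, optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)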
gen_img_diffusers.py CHANGED
@@ -47,7 +47,7 @@ VGG(
 """
 
 import json
-from typing import List, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
 import glob
 import importlib
 import inspect
@@ -60,7 +60,6 @@ import math
 import os
 import random
 import re
-from typing import Any, Callable, List, Optional, Union
 
 import diffusers
 import numpy as np
@@ -81,6 +80,9 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
 import library.model_util as model_util
+import library.train_util as train_util
+import tools.original_control_net as original_control_net
+from tools.original_control_net import ControlNetInfo
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +489,9 @@ class PipelineLike():
    self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
    self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
 
+    # ControlNet
+    self.control_nets: List[ControlNetInfo] = []
+
  # Textual Inversion
  def add_token_replacement(self, target_token_id, rep_token_ids):
    self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +505,11 @@ class PipelineLike():
      new_tokens.append(token)
    return new_tokens
 
+  def set_control_nets(self, ctrl_nets):
+    self.control_nets = ctrl_nets
+
  # region xformersとか使う部分:独自に書き換えるので関係なし
+
  def enable_xformers_memory_efficient_attention(self):
    r"""
    Enable memory efficient attention as implemented in xformers.
@@ -581,6 +590,8 @@ class PipelineLike():
      latents: Optional[torch.FloatTensor] = None,
      max_embeddings_multiples: Optional[int] = 3,
      output_type: Optional[str] = "pil",
+      vae_batch_size: float = None,
+      return_latents: bool = False,
      # return_dict: bool = True,
      callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
      is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -672,6 +683,9 @@ class PipelineLike():
    else:
      raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+    vae_batch_size = batch_size if vae_batch_size is None else (
+        int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
+
    if strength < 0 or strength > 1:
      raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
@@ -752,7 +766,7 @@ class PipelineLike():
      text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
      text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # prompt複数件でもOK
 
-    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
      if isinstance(clip_guide_images, PIL.Image.Image):
        clip_guide_images = [clip_guide_images]
 
@@ -765,7 +779,7 @@ class PipelineLike():
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
        if len(image_embeddings_clip) == 1:
          image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-      else:
+      elif self.vgg16_guidance_scale > 0:
        size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # とりあえず1/4に(小さいか?)
        clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
        clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -774,6 +788,10 @@ class PipelineLike():
        image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
        if len(image_embeddings_vgg16) == 1:
          image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+      else:
+        # ControlNetのhintにguide imageを流用する
+        # 前処理はControlNet側で行う
+        pass
 
    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -781,7 +799,6 @@ class PipelineLike():
    latents_dtype = text_embeddings.dtype
    init_latents_orig = None
    mask = None
-    noise = None
 
    if init_image is None:
      # get the initial random noise unless the user supplied it
@@ -813,6 +830,8 @@ class PipelineLike():
      if isinstance(init_image[0], PIL.Image.Image):
        init_image = [preprocess_image(im) for im in init_image]
        init_image = torch.cat(init_image)
+      if isinstance(init_image, list):
+        init_image = torch.stack(init_image)
 
      # mask image to tensor
      if mask_image is not None:
@@ -823,9 +842,24 @@ class PipelineLike():
 
      # encode the init image into latents and scale the latents
      init_image = init_image.to(device=self.device, dtype=latents_dtype)
-      init_latent_dist = self.vae.encode(init_image).latent_dist
-      init_latents = init_latent_dist.sample(generator=generator)
-      init_latents = 0.18215 * init_latents
+      if init_image.size()[2:] == (height // 8, width // 8):
+        init_latents = init_image
+      else:
+        if vae_batch_size >= batch_size:
+          init_latent_dist = self.vae.encode(init_image).latent_dist
+          init_latents = init_latent_dist.sample(generator=generator)
+        else:
+          if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+          init_latents = []
+          for i in tqdm(range(0, batch_size, vae_batch_size)):
+            init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
+                                               if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
+            init_latents.append(init_latent_dist.sample(generator=generator))
+          init_latents = torch.cat(init_latents)
+
+        init_latents = 0.18215 * init_latents
+
      if len(init_latents) == 1:
        init_latents = init_latents.repeat((batch_size, 1, 1, 1))
      init_latents_orig = init_latents
@@ -864,12 +898,21 @@ class PipelineLike():
      extra_step_kwargs["eta"] = eta
 
    num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
+
+    if self.control_nets:
+      guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+
    for i, t in enumerate(tqdm(timesteps)):
      # expand the latents if we are doing classifier free guidance
      latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
      latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
      # predict the noise residual
-      noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+      if self.control_nets:
+        noise_pred = original_control_net.call_unet_and_control_net(
+            i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+      else:
+        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
      # perform guidance
      if do_classifier_free_guidance:
@@ -911,8 +954,19 @@ class PipelineLike():
      if is_cancelled_callback is not None and is_cancelled_callback():
        return None
 
+    if return_latents:
+      return (latents, False)
+
    latents = 1 / 0.18215 * latents
-    image = self.vae.decode(latents).sample
+    if vae_batch_size >= batch_size:
+      image = self.vae.decode(latents).sample
+    else:
+      if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+      images = []
+      for i in tqdm(range(0, batch_size, vae_batch_size)):
+        images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
+      image = torch.cat(images)
 
    image = (image / 2 + 0.5).clamp(0, 1)
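The new vae_batch_size argument above takes either an absolute chunk size or a fraction of the prompt batch: with batch_size=8 and vae_batch_size=0.25, the VAE encodes/decodes max(1, int(8 * 0.25)) = 2 images at a time. The decode loop from the hunk above, isolated as a standalone sketch (names illustrative):

import torch

def decode_latents_in_chunks(vae, latents, vae_batch_size):
  # Decode a large latent batch in small chunks to cap peak VRAM use.
  latents = latents / 0.18215  # undo the SD latent scaling, as the pipeline does
  images = []
  for i in range(0, latents.shape[0], vae_batch_size):
    images.append(vae.decode(latents[i:i + vae_batch_size]).sample)
  return torch.cat(images)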
 
@@ -1595,10 +1649,11 @@ def get_unweighted_text_embeddings(
      if pad == eos:  # v1
        text_input_chunk[:, -1] = text_input[0, -1]
      else:  # v2
-        if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad:  # 最後に普通の文字がある
-          text_input_chunk[:, -1] = eos
-        if text_input_chunk[:, 1] == pad:  # BOSだけであとはPAD
-          text_input_chunk[:, 1] = eos
+        for j in range(len(text_input_chunk)):
+          if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # 最後に普通の文字がある
+            text_input_chunk[j, -1] = eos
+          if text_input_chunk[j, 1] == pad:  # BOSだけであとはPAD
+            text_input_chunk[j, 1] = eos
 
      if clip_skip is None or clip_skip == 1:
        text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -1799,7 +1854,7 @@ def preprocess_mask(mask):
  mask = mask.convert("L")
  w, h = mask.size
  w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
+  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # LANCZOS)
  mask = np.array(mask).astype(np.float32) / 255.0
  mask = np.tile(mask, (4, 1, 1))
  mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
@@ -1817,6 +1872,35 @@ def preprocess_mask(mask):
 # return text_encoder
 
 
+class BatchDataBase(NamedTuple):
+  # バッチ分割が必要ないデータ
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # バッチ分割が必要なデータ
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  return_latents: bool
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
 def main(args):
  if args.fp16:
    dtype = torch.float16
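The three NamedTuples above separate per-image fields (BatchDataBase) from the fields that force a new batch (BatchDataExt); later in main(), prompts are accumulated while their ext parts compare equal. A small illustrative use (values made up):

base = BatchDataBase(step=0, prompt="a cat", negative_prompt="", seed=1234,
                     init_image=None, mask_image=None, clip_prompt=None, guide_image=None)
ext = BatchDataExt(width=512, height=512, steps=30, scale=7.5,
                   negative_scale=None, strength=0.8, network_muls=None)
entry = BatchData(return_latents=False, base=base, ext=ext)

# Batch-splitting rule used in main(): start a new batch whenever ext changes,
# since width/height/steps/scale must be uniform within one pipe() call.
# if len(batch_data) > 0 and batch_data[-1].ext != entry.ext:
#   process_batch(batch_data, highres_fix); batch_data.clear()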
@@ -1881,10 +1965,7 @@ def main(args):
  # tokenizerを読み込む
  print("loading tokenizer")
  if use_stable_diffusion_format:
-    if args.v2:
-      tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-    else:
-      tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)
 
  # schedulerを用意する
  sched_init_args = {}
@@ -1995,11 +2076,13 @@ def main(args):
  # networkを組み込む
  if args.network_module:
    networks = []
+    network_default_muls = []
    for i, network_module in enumerate(args.network_module):
      print("import network module:", network_module)
      imported_module = importlib.import_module(network_module)
 
      network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)
 
      net_kwargs = {}
      if args.network_args and i < len(args.network_args):
@@ -2014,7 +2097,7 @@ def main(args):
        network_weight = args.network_weights[i]
        print("load network weights from:", network_weight)
 
-        if model_util.is_safetensors(network_weight):
+        if model_util.is_safetensors(network_weight) and args.network_show_meta:
          from safetensors.torch import safe_open
          with safe_open(network_weight, framework="pt") as f:
            metadata = f.metadata()
@@ -2037,6 +2120,18 @@ def main(args):
  else:
    networks = []
 
+  # ControlNetの処理
+  control_nets: List[ControlNetInfo] = []
+  if args.control_net_models:
+    for i, model in enumerate(args.control_net_models):
+      prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+      weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+      ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+      ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+      prep = original_control_net.load_preprocess(prep_type)
+      control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+
  if args.opt_channels_last:
    print(f"set optimizing: channels last")
    text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2145,14 @@ def main(args):
    if vgg16_model is not None:
      vgg16_model.to(memory_format=torch.channels_last)
 
+    for cn in control_nets:
+      cn.unet.to(memory_format=torch.channels_last)
+      cn.net.to(memory_format=torch.channels_last)
+
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                      clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                      vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+  pipe.set_control_nets(control_nets)
  print("pipeline is ready.")
 
  if args.diffusers_xformers:
@@ -2177,18 +2277,34 @@ def main(args):
      mask_images = l
 
  # 画像サイズにオプション指定があるときはリサイズする
-  if init_images is not None and args.W is not None and args.H is not None:
-    print(f"resize img2img source images to {args.W}*{args.H}")
-    init_images = resize_images(init_images, (args.W, args.H))
+  if args.W is not None and args.H is not None:
+    if init_images is not None:
+      print(f"resize img2img source images to {args.W}*{args.H}")
+      init_images = resize_images(init_images, (args.W, args.H))
    if mask_images is not None:
      print(f"resize img2img mask images to {args.W}*{args.H}")
      mask_images = resize_images(mask_images, (args.W, args.H))
 
+  if networks and mask_images:
+    # mask を領域情報として流用する、現在は1枚だけ対応
+    # TODO 複数のnetwork classの混在時の考慮
+    print("use mask as region")
+    # import cv2
+    # for i in range(3):
+    #   cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
+    #   cv2.waitKey()
+    #   cv2.destroyAllWindows()
+    networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+    mask_images = None
+
  prev_image = None  # for VGG16 guided
  if args.guide_image_path is not None:
-    print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-    guide_images = load_images(args.guide_image_path)
-    print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
+    print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+    guide_images = []
+    for p in args.guide_image_path:
+      guide_images.extend(load_images(p))
+
+    print(f"loaded {len(guide_images)} guide images for guidance")
    if len(guide_images) == 0:
      print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
      guide_images = None
@@ -2219,33 +2335,46 @@ def main(args):
  iter_seed = random.randint(0, 0x7fffffff)
 
  # バッチ処理の関数
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
    batch_size = len(batch)
 
    # highres_fixの処理
    if highres_fix and not highres_1st:
-      # 1st stageのバッチを作成して呼び出す
-      print("process 1st stage1")
+      # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
+      print("process 1st stage")
      batch_1st = []
-      for params1, (width, height, steps, scale, negative_scale, strength) in batch:
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for _, base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
        width_1st = width_1st - width_1st % 32
        height_1st = height_1st - height_1st % 32
-        batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
+
+        ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                               ext.negative_scale, ext.strength, ext.network_muls)
+        batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
      images_1st = process_batch(batch_1st, True, True)
 
      # 2nd stageのバッチを作成して以下処理する
-      print("process 2nd stage1")
+      print("process 2nd stage")
+      if args.highres_fix_latents_upscaling:
+        org_dtype = images_1st.dtype
+        if images_1st.dtype == torch.bfloat16:
+          images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
+        images_1st = torch.nn.functional.interpolate(
+            images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear')  # , antialias=True)
+        images_1st = images_1st.to(org_dtype)
+
      batch_2nd = []
-      for i, (b1, image) in enumerate(zip(batch, images_1st)):
-        image = image.resize((width, height), resample=PIL.Image.LANCZOS)
-        (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
-        batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        if not args.highres_fix_latents_upscaling:
+          image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        batch_2nd.append(bd_2nd)
      batch = batch_2nd
 
-    (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
-     height, steps, scale, negative_scale, strength) = batch[0]
+    # このバッチの情報を取り出す
+    return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+        (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
    noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
 
    prompts = []
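With --highres_fix_latents_upscaling, process_batch above feeds the 2nd stage with upsampled 1st-stage latents instead of resized PIL images; the dtype round-trip exists because interpolate has no bfloat16 kernel. The same step as a standalone sketch (SD latents are 1/8 of the pixel resolution):

import torch

def upscale_latents(latents_1st, target_height, target_width):
  # latents_1st: (N, 4, h/8, w/8) latents from the low-resolution pass.
  org_dtype = latents_1st.dtype
  if org_dtype == torch.bfloat16:
    latents_1st = latents_1st.float()  # interpolate does not support bf16
  upscaled = torch.nn.functional.interpolate(
      latents_1st, (target_height // 8, target_width // 8), mode="bilinear")
  return upscaled.to(org_dtype)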
 
@@ -2278,7 +2407,7 @@ def main(args):
    all_images_are_same = True
    all_masks_are_same = True
    all_guide_images_are_same = True
-    for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
+    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
      prompts.append(prompt)
      negative_prompts.append(negative_prompt)
      seeds.append(seed)
@@ -2295,9 +2424,13 @@ def main(args):
          all_masks_are_same = mask_images[-2] is mask_image
 
      if guide_image is not None:
-        guide_images.append(guide_image)
-        if i > 0 and all_guide_images_are_same:
-          all_guide_images_are_same = guide_images[-2] is guide_image
+        if type(guide_image) is list:
+          guide_images.extend(guide_image)
+          all_guide_images_are_same = False
+        else:
+          guide_images.append(guide_image)
+          if i > 0 and all_guide_images_are_same:
+            all_guide_images_are_same = guide_images[-2] is guide_image
 
      # make start code
      torch.manual_seed(seed)
@@ -2320,10 +2453,24 @@ def main(args):
    if guide_images is not None and all_guide_images_are_same:
      guide_images = guide_images[0]
 
+    # ControlNet使用時はguide imageをリサイズする
+    if control_nets:
+      # TODO resampleのメソッド
+      guide_images = guide_images if type(guide_images) == list else [guide_images]
+      guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+      if len(guide_images) == 1:
+        guide_images = guide_images[0]
+
    # generate
+    if networks:
+      for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+        n.set_multiplier(m)
+
    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
-                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
-    if highres_1st and not args.highres_fix_save_1st:
+                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
+                  vae_batch_size=args.vae_batch_size, return_latents=return_latents,
+                  clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
+    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
      return images
 
    # save image
@@ -2398,6 +2545,7 @@ def main(args):
      strength = 0.8 if args.strength is None else args.strength
      negative_prompt = ""
      clip_prompt = None
+      network_muls = None
 
      prompt_args = prompt.strip().split(' --')
      prompt = prompt_args[0]
@@ -2461,6 +2609,15 @@ def main(args):
            clip_prompt = m.group(1)
            print(f"clip prompt: {clip_prompt}")
            continue
+
+          m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+          if m:  # network multiplies
+            network_muls = [float(v) for v in m.group(1).split(",")]
+            while len(network_muls) < len(networks):
+              network_muls.append(network_muls[-1])
+            print(f"network mul: {network_muls}")
+            continue
+
        except ValueError as ex:
          print(f"Exception in parsing / 解析エラー: {parg}")
          print(ex)
@@ -2498,7 +2655,12 @@ def main(args):
        mask_image = mask_images[global_step % len(mask_images)]
 
      if guide_images is not None:
-        guide_image = guide_images[global_step % len(guide_images)]
+        if control_nets:  # 複数件の場合あり
+          c = len(control_nets)
+          p = global_step % (len(guide_images) // c)
+          guide_image = guide_images[p * c:p * c + c]
+        else:
+          guide_image = guide_images[global_step % len(guide_images)]
      elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
        if prev_image is None:
          print("Generate 1st image without guide image.")
@@ -2506,10 +2668,9 @@ def main(args):
          print("Use previous image as guide image.")
          guide_image = prev_image
 
-      # TODO named tupleか何かにする
-      b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
-            (width, height, steps, scale, negative_scale, strength))
-      if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
+      b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                     BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
+      if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
        process_batch(batch_data, highres_fix)
        batch_data.clear()
 
@@ -2553,6 +2714,8 @@ if __name__ == '__main__':
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
+  parser.add_argument("--vae_batch_size", type=float, default=None,
+                      help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
  parser.add_argument('--sampler', type=str, default='ddim',
                      choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2564,6 +2727,8 @@ if __name__ == '__main__':
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
  parser.add_argument("--vae", type=str, default=None,
                      help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
  #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
  parser.add_argument("--seed", type=int, default=None,
@@ -2578,12 +2743,15 @@ if __name__ == '__main__':
  parser.add_argument("--opt_channels_last", action='store_true',
                      help='set channels last option to model / モデルにchannels lastを指定し最適化する')
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='Hypernetwork weights to load / Hypernetworkの重み')
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
                      help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                      help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,15 +2765,26 @@ if __name__ == '__main__':
                      help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                      help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-  parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
+  parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                      help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
  parser.add_argument("--highres_fix_scale", type=float, default=None,
                      help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
  parser.add_argument("--highres_fix_steps", type=int, default=28,
                      help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
  parser.add_argument("--highres_fix_save_1st", action='store_true',
                      help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
+                      help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
  parser.add_argument("--negative_scale", type=float, default=None,
                      help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
 
+  parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                      help='ControlNet models to use / 使用するControlNetのモデル名')
+  parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                      help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+  parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+  parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                      help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
+
  args = parser.parse_args()
  main(args)
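When several ControlNets are active, the guide-image hunk above consumes images c at a time, one per ControlNet: with 2 ControlNets and 6 guide images, generation 0 uses images [0:2], generation 1 uses [2:4], and the pattern wraps after 3 generations. The indexing, isolated:

def pick_guide_images(guide_images, control_nets, global_step):
  # Return one guide image per ControlNet for this generation, wrapping around the list.
  c = len(control_nets)
  p = global_step % (len(guide_images) // c)
  return guide_images[p * c:p * c + c]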
library/model_util.py CHANGED
@@ -4,7 +4,7 @@
 import math
 import os
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 
@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
  else:
    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+    logging.set_verbosity_error()  # don't show annoying warning
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    logging.set_verbosity_warning()
+
    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    print("loading text encoder:", info)
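The logging calls above temporarily silence the warnings transformers prints when CLIPTextModel.from_pretrained initializes weights that are immediately overwritten by the converted checkpoint, then restore the default level. The same pattern in isolation:

from transformers import CLIPTextModel, logging

logging.set_verbosity_error()    # hide the expected initialization warnings
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
logging.set_verbosity_warning()  # transformers' default verbosity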
 
library/train_util.py CHANGED
@@ -1,12 +1,21 @@
 # common functions for training
 
 import argparse
 import json
 import shutil
 import time
-from typing import Dict, List, NamedTuple, Tuple
 from accelerate import Accelerator
-from torch.autograd.function import Function
 import glob
 import math
 import os
@@ -17,10 +26,16 @@ from io import BytesIO
 
 from tqdm import tqdm
 import torch
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import diffusers
-from diffusers import DDPMScheduler, StableDiffusionPipeline
 import albumentations as albu
 import numpy as np
 from PIL import Image
@@ -195,23 +210,95 @@ class BucketBatchIndex(NamedTuple):
  batch_index: int
 
 
 class BaseDataset(torch.utils.data.Dataset):
-  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
    super().__init__()
-    self.tokenizer: CLIPTokenizer = tokenizer
    self.max_token_length = max_token_length
-    self.shuffle_caption = shuffle_caption
-    self.shuffle_keep_tokens = shuffle_keep_tokens
    # width/height is used when enable_bucket==False
    self.width, self.height = (None, None) if resolution is None else resolution
-    self.face_crop_aug_range = face_crop_aug_range
-    self.flip_aug = flip_aug
-    self.color_aug = color_aug
    self.debug_dataset = debug_dataset
-    self.random_crop = random_crop
    self.token_padding_disabled = False
-    self.dataset_dirs_info = {}
-    self.reg_dataset_dirs_info = {}
    self.tag_frequency = {}
 
    self.enable_bucket = False
@@ -225,49 +312,28 @@ class BaseDataset(torch.utils.data.Dataset):
    self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
    self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
-    self.dropout_rate: float = 0
-    self.dropout_every_n_epochs: int = None
-    self.tag_dropout_rate: float = 0
 
    # augmentation
-    flip_p = 0.5 if flip_aug else 0.0
-    if color_aug:
-      # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
-      self.aug = albu.Compose([
-        albu.OneOf([
-          albu.HueSaturationValue(8, 0, 0, p=.5),
-          albu.RandomGamma((95, 105), p=.5),
-        ], p=.33),
-        albu.HorizontalFlip(p=flip_p)
-      ], p=1.)
-    elif flip_aug:
-      self.aug = albu.Compose([
-        albu.HorizontalFlip(p=flip_p)
-      ], p=1.)
-    else:
-      self.aug = None
 
    self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
 
    self.image_data: Dict[str, ImageInfo] = {}
 
    self.replacements = {}
 
  def set_current_epoch(self, epoch):
    self.current_epoch = epoch
-
-  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
-    # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
-    self.dropout_rate = dropout_rate
-    self.dropout_every_n_epochs = dropout_every_n_epochs
-    self.tag_dropout_rate = tag_dropout_rate
 
  def set_tag_frequency(self, dir_name, captions):
    frequency_for_dir = self.tag_frequency.get(dir_name, {})
    self.tag_frequency[dir_name] = frequency_for_dir
    for caption in captions:
      for tag in caption.split(","):
-        if tag and not tag.isspace():
          tag = tag.lower()
          frequency = frequency_for_dir.get(tag, 0)
          frequency_for_dir[tag] = frequency + 1
@@ -278,42 +344,36 @@ class BaseDataset(torch.utils.data.Dataset):
  def add_replacement(self, str_from, str_to):
    self.replacements[str_from] = str_to
 
-  def process_caption(self, caption):
    # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
-    is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
-    is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
 
    if is_drop_out:
      caption = ""
    else:
-      if self.shuffle_caption or self.tag_dropout_rate > 0:
        def dropout_tags(tokens):
-          if self.tag_dropout_rate <= 0:
            return tokens
          l = []
          for token in tokens:
-            if random.random() >= self.tag_dropout_rate:
              l.append(token)
          return l
 
-        tokens = [t.strip() for t in caption.strip().split(",")]
-        if self.shuffle_keep_tokens is None:
-          if self.shuffle_caption:
-            random.shuffle(tokens)
-
-          tokens = dropout_tags(tokens)
-        else:
-          if len(tokens) > self.shuffle_keep_tokens:
-            keep_tokens = tokens[:self.shuffle_keep_tokens]
-            tokens = tokens[self.shuffle_keep_tokens:]
 
-          if self.shuffle_caption:
-            random.shuffle(tokens)
 
-          tokens = dropout_tags(tokens)
 
-          tokens = keep_tokens + tokens
-          caption = ", ".join(tokens)
 
    # textual inversion対応
    for str_from, str_to in self.replacements.items():
@@ -367,8 +427,9 @@ class BaseDataset(torch.utils.data.Dataset):
    input_ids = torch.stack(iids_list)  # 3,77
    return input_ids
 
-  def register_image(self, info: ImageInfo):
    self.image_data[info.image_key] = info
 
  def make_buckets(self):
    '''
@@ -467,7 +528,7 @@ class BaseDataset(torch.utils.data.Dataset):
    img = np.array(image, np.uint8)
    return img
 
-  def trim_and_resize_if_required(self, image, reso, resized_size):
471
  image_height, image_width = image.shape[0:2]
472
 
473
  if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -477,22 +538,27 @@ class BaseDataset(torch.utils.data.Dataset):
477
  image_height, image_width = image.shape[0:2]
478
  if image_width > reso[0]:
479
  trim_size = image_width - reso[0]
480
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
481
  # print("w", trim_size, p)
482
  image = image[:, p:p + reso[0]]
483
  if image_height > reso[1]:
484
  trim_size = image_height - reso[1]
485
- p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
486
  # print("h", trim_size, p)
487
  image = image[p:p + reso[1]]
488
 
489
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
490
  return image
491
 
 
 
 
492
  def cache_latents(self, vae):
493
  # TODO ここを高速化したい
494
  print("caching latents.")
495
  for info in tqdm(self.image_data.values()):
 
 
496
  if info.latents_npz is not None:
497
  info.latents = self.load_latents_from_npz(info, False)
498
  info.latents = torch.FloatTensor(info.latents)
@@ -502,13 +568,13 @@ class BaseDataset(torch.utils.data.Dataset):
502
  continue
503
 
504
  image = self.load_image(info.absolute_path)
505
- image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
506
 
507
  img_tensor = self.image_transforms(image)
508
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
509
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
510
 
511
- if self.flip_aug:
512
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
513
  img_tensor = self.image_transforms(image)
514
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -518,11 +584,11 @@ class BaseDataset(torch.utils.data.Dataset):
518
  image = Image.open(image_path)
519
  return image.size
520
 
521
- def load_image_with_face_info(self, image_path: str):
522
  img = self.load_image(image_path)
523
 
524
  face_cx = face_cy = face_w = face_h = 0
525
- if self.face_crop_aug_range is not None:
526
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
527
  if len(tokens) >= 5:
528
  face_cx = int(tokens[-4])
@@ -533,7 +599,7 @@ class BaseDataset(torch.utils.data.Dataset):
533
  return img, face_cx, face_cy, face_w, face_h
534
 
535
  # いい感じに切り出す
536
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
537
  height, width = image.shape[0:2]
538
  if height == self.height and width == self.width:
539
  return image
@@ -541,8 +607,8 @@ class BaseDataset(torch.utils.data.Dataset):
541
  # 画像サイズはsizeより大きいのでリサイズする
542
  face_size = max(face_w, face_h)
543
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
544
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
545
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
546
  if min_scale >= max_scale: # range指定がmin==max
547
  scale = min_scale
548
  else:
@@ -560,13 +626,13 @@ class BaseDataset(torch.utils.data.Dataset):
560
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
561
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
562
 
563
- if self.random_crop:
564
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
565
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
566
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
567
  else:
568
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
569
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
570
  if face_size > self.size // 10 and face_size >= 40:
571
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
572
 
@@ -589,9 +655,6 @@ class BaseDataset(torch.utils.data.Dataset):
589
  return self._length
590
 
591
  def __getitem__(self, index):
592
- if index == 0:
593
- self.shuffle_buckets()
594
-
595
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
596
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
597
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -604,28 +667,29 @@ class BaseDataset(torch.utils.data.Dataset):
604
 
605
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
606
  image_info = self.image_data[image_key]
 
607
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
608
 
609
  # image/latentsを処理する
610
  if image_info.latents is not None:
611
- latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
612
  image = None
613
  elif image_info.latents_npz is not None:
614
- latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
615
  latents = torch.FloatTensor(latents)
616
  image = None
617
  else:
618
  # 画像を読み込み、必要ならcropする
619
- img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
620
  im_h, im_w = img.shape[0:2]
621
 
622
  if self.enable_bucket:
623
- img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
624
  else:
625
  if face_cx > 0: # 顔位置情報あり
626
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
627
  elif im_h > self.height or im_w > self.width:
628
- assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
629
  if im_h > self.height:
630
  p = random.randint(0, im_h - self.height)
631
  img = img[p:p + self.height]
@@ -637,8 +701,9 @@ class BaseDataset(torch.utils.data.Dataset):
637
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
638
 
639
  # augmentation
640
- if self.aug is not None:
641
- img = self.aug(image=img)['image']
 
642
 
643
  latents = None
644
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -646,7 +711,7 @@ class BaseDataset(torch.utils.data.Dataset):
646
  images.append(image)
647
  latents_list.append(latents)
648
 
649
- caption = self.process_caption(image_info.caption)
650
  captions.append(caption)
651
  if not self.token_padding_disabled: # this option might be omitted in future
652
  input_ids_list.append(self.get_input_ids(caption))
@@ -677,9 +742,8 @@ class BaseDataset(torch.utils.data.Dataset):
677
 
678
 
679
  class DreamBoothDataset(BaseDataset):
680
- def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
681
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
682
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
683
 
684
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
685
 
@@ -702,7 +766,7 @@ class DreamBoothDataset(BaseDataset):
702
  self.bucket_reso_steps = None # この情報は使われない
703
  self.bucket_no_upscale = False
704
 
705
- def read_caption(img_path):
706
  # captionの候補ファイル名を作る
707
  base_name = os.path.splitext(img_path)[0]
708
  base_name_face_det = base_name
@@ -725,153 +789,181 @@ class DreamBoothDataset(BaseDataset):
725
  break
726
  return caption
727
 
728
- def load_dreambooth_dir(dir):
729
- if not os.path.isdir(dir):
730
- # print(f"ignore file: {dir}")
731
- return 0, [], []
732
 
733
- tokens = os.path.basename(dir).split('_')
734
- try:
735
- n_repeats = int(tokens[0])
736
- except ValueError as e:
737
- print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
738
- return 0, [], []
739
-
740
- caption_by_folder = '_'.join(tokens[1:])
741
- img_paths = glob_images(dir, "*")
742
- print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
743
 
744
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
745
  captions = []
746
  for img_path in img_paths:
747
- cap_for_img = read_caption(img_path)
748
- captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
749
 
750
- self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
751
 
752
- return n_repeats, img_paths, captions
753
 
754
- print("prepare train images.")
755
- train_dirs = os.listdir(train_data_dir)
756
  num_train_images = 0
757
- for dir in train_dirs:
758
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
759
- num_train_images += n_repeats * len(img_paths)
760
 
761
  for img_path, caption in zip(img_paths, captions):
762
- info = ImageInfo(img_path, n_repeats, caption, False, img_path)
763
- self.register_image(info)
 
 
 
764
 
765
- self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
766
 
767
  print(f"{num_train_images} train images with repeating.")
768
  self.num_train_images = num_train_images
769
 
770
- # reg imageは数を数えて学習画像と同じ枚数にする
771
- num_reg_images = 0
772
- if reg_data_dir:
773
- print("prepare reg images.")
774
- reg_infos: List[ImageInfo] = []
775
 
776
- reg_dirs = os.listdir(reg_data_dir)
777
- for dir in reg_dirs:
778
- n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
779
- num_reg_images += n_repeats * len(img_paths)
780
 
781
- for img_path, caption in zip(img_paths, captions):
782
- info = ImageInfo(img_path, n_repeats, caption, True, img_path)
783
- reg_infos.append(info)
784
 
785
- self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
786
 
787
- print(f"{num_reg_images} reg images.")
788
- if num_train_images < num_reg_images:
789
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
790
 
791
- if num_reg_images == 0:
792
- print("no regularization images / 正則化画像が見つかりませんでした")
793
  else:
794
- # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
795
- n = 0
796
- first_loop = True
797
- while n < num_train_images:
798
- for info in reg_infos:
799
- if first_loop:
800
- self.register_image(info)
801
- n += info.num_repeats
802
- else:
803
- info.num_repeats += 1
804
- n += 1
805
- if n >= num_train_images:
806
- break
807
- first_loop = False
808
 
809
- self.num_reg_images = num_reg_images
 
 
810
811
 
812
- class FineTuningDataset(BaseDataset):
813
- def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
814
- super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
815
- resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
816
-
817
- # メタデータを読み込む
818
- if os.path.exists(json_file_name):
819
- print(f"loading existing metadata: {json_file_name}")
820
- with open(json_file_name, "rt", encoding='utf-8') as f:
821
- metadata = json.load(f)
822
- else:
823
- raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
824
 
825
- self.metadata = metadata
826
- self.train_data_dir = train_data_dir
827
- self.batch_size = batch_size
828
 
829
- tags_list = []
830
- for image_key, img_md in metadata.items():
831
- # path情報を作る
832
- if os.path.exists(image_key):
833
- abs_path = image_key
834
- else:
835
- # わりといい加減だがいい方法が思いつかん
836
- abs_path = glob_images(train_data_dir, image_key)
837
- assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
838
- abs_path = abs_path[0]
839
-
840
- caption = img_md.get('caption')
841
- tags = img_md.get('tags')
842
- if caption is None:
843
- caption = tags
844
- elif tags is not None and len(tags) > 0:
845
- caption = caption + ', ' + tags
846
- tags_list.append(tags)
847
- assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
848
-
849
- image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
850
- image_info.image_size = img_md.get('train_resolution')
851
-
852
- if not self.color_aug and not self.random_crop:
853
- # if npz exists, use them
854
- image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
855
-
856
- self.register_image(image_info)
857
- self.num_train_images = len(metadata) * dataset_repeats
858
- self.num_reg_images = 0
859
 
860
- # TODO do not record tag freq when no tag
861
- self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
862
- self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
 
 
 
863
 
864
  # check existence of all npz files
865
- use_npz_latents = not (self.color_aug or self.random_crop)
866
  if use_npz_latents:
 
867
  npz_any = False
868
  npz_all = True
 
869
  for image_info in self.image_data.values():
 
 
870
  has_npz = image_info.latents_npz is not None
871
  npz_any = npz_any or has_npz
872
 
873
- if self.flip_aug:
874
  has_npz = has_npz and image_info.latents_npz_flipped is not None
 
875
  npz_all = npz_all and has_npz
876
 
877
  if npz_any and not npz_all:
@@ -883,7 +975,7 @@ class FineTuningDataset(BaseDataset):
883
  elif not npz_all:
884
  use_npz_latents = False
885
  print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
886
- if self.flip_aug:
887
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
888
  # else:
889
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -929,7 +1021,7 @@ class FineTuningDataset(BaseDataset):
929
  for image_info in self.image_data.values():
930
  image_info.latents_npz = image_info.latents_npz_flipped = None
931
 
932
- def image_key_to_npz_file(self, image_key):
933
  base_name = os.path.splitext(image_key)[0]
934
  npz_file_norm = base_name + '.npz'
935
 
@@ -941,8 +1033,8 @@ class FineTuningDataset(BaseDataset):
941
  return npz_file_norm, npz_file_flip
942
 
943
  # image_key is relative path
944
- npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
945
- npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
946
 
947
  if not os.path.exists(npz_file_norm):
948
  npz_file_norm = None
@@ -953,13 +1045,60 @@ class FineTuningDataset(BaseDataset):
953
  return npz_file_norm, npz_file_flip
954
 
955
956
  def debug_dataset(train_dataset, show_input_ids=False):
957
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
958
  print("Escape for exit. / Escキーで中断、終了します")
959
 
960
  train_dataset.set_current_epoch(1)
961
  k = 0
962
- for i, example in enumerate(train_dataset):
 
 
 
963
  if example['latents'] is not None:
964
  print(f"sample has latents from npz file: {example['latents'].size()}")
965
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
@@ -1364,6 +1503,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1364
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1365
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1366
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
 
1368
 
1369
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1387,10 +1555,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1387
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1388
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1389
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1390
- parser.add_argument("--use_8bit_adam", action="store_true",
1391
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1392
- parser.add_argument("--use_lion_optimizer", action="store_true",
1393
- help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1394
  parser.add_argument("--mem_eff_attn", action="store_true",
1395
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1396
  parser.add_argument("--xformers", action="store_true",
@@ -1398,7 +1562,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1398
  parser.add_argument("--vae", type=str, default=None,
1399
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1400
 
1401
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1402
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1403
  parser.add_argument("--max_train_epochs", type=int, default=None,
1404
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1419,15 +1582,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
1419
  parser.add_argument("--logging_dir", type=str, default=None,
1420
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1421
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1422
- parser.add_argument("--lr_scheduler", type=str, default="constant",
1423
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1424
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
1425
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1426
  parser.add_argument("--noise_offset", type=float, default=None,
1427
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1428
  parser.add_argument("--lowram", action="store_true",
1429
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1430
 
  if support_dreambooth:
1432
  # DreamBooth training
1433
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1449,8 +1620,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1449
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1450
  parser.add_argument("--caption_extention", type=str, default=None,
1451
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1452
- parser.add_argument("--keep_tokens", type=int, default=None,
1453
- help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1454
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1455
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1456
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1475,11 +1646,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
1475
  if support_caption_dropout:
1476
  # Textual Inversion はcaptionのdropoutをsupportしない
1477
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1478
- parser.add_argument("--caption_dropout_rate", type=float, default=0,
1479
  help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1480
- parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1481
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1482
- parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1483
  help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1484
 
1485
  if support_dreambooth:
@@ -1504,16 +1675,256 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1504
  # region utils
1505
 
1506
1507
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1508
  # backward compatibility
1509
  if args.caption_extention is not None:
1510
  args.caption_extension = args.caption_extention
1511
  args.caption_extention = None
1512
 
1513
- if args.cache_latents:
1514
- assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1515
- assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1516
-
1517
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1518
  if args.resolution is not None:
1519
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
@@ -1536,12 +1947,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1536
 
1537
  def load_tokenizer(args: argparse.Namespace):
1538
  print("prepare tokenizer")
1539
- if args.v2:
1540
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1541
- else:
1542
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1543
- if args.max_token_length is not None:
1544
  print(f"update token length: {args.max_token_length}")
1545
  return tokenizer
1546
 
1547
 
@@ -1592,13 +2019,19 @@ def prepare_dtype(args: argparse.Namespace):
1592
 
1593
 
1594
  def load_target_model(args: argparse.Namespace, weight_dtype):
1595
- load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
 
 
1596
  if load_stable_diffusion_format:
1597
  print("load StableDiffusion checkpoint")
1598
- text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1599
  else:
1600
  print("load Diffusers pretrained models")
1601
- pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1602
  text_encoder = pipe.text_encoder
1603
  vae = pipe.vae
1604
  unet = pipe.unet
@@ -1767,6 +2200,197 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1767
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1768
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
 
1770
  # endregion
1771
 
1772
  # region 前処理用
 
1
  # common functions for training
2
 
3
  import argparse
4
+ import importlib
5
  import json
6
+ import re
7
  import shutil
8
  import time
9
+ from typing import (
10
+ Dict,
11
+ List,
12
+ NamedTuple,
13
+ Optional,
14
+ Sequence,
15
+ Tuple,
16
+ Union,
17
+ )
18
  from accelerate import Accelerator
 
19
  import glob
20
  import math
21
  import os
 
26
 
27
  from tqdm import tqdm
28
  import torch
29
+ from torch.optim import Optimizer
30
  from torchvision import transforms
31
  from transformers import CLIPTokenizer
32
+ import transformers
33
  import diffusers
34
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+ from diffusers import (StableDiffusionPipeline, DDPMScheduler,
36
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
37
+ LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
38
+ KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
39
  import albumentations as albu
40
  import numpy as np
41
  from PIL import Image
 
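The new imports bring in SchedulerType and TYPE_TO_SCHEDULER_FUNCTION alongside several sampler scheduler classes. Together with the --lr_scheduler_num_cycles and --lr_scheduler_power options added further down, this suggests the learning-rate scheduler is now built by dispatching on SchedulerType rather than by calling diffusers.optimization.get_scheduler directly. The following is only a hedged sketch of such a wrapper; the name and exact branching are assumptions:

import torch
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

def get_scheduler_sketch(name: str, optimizer: torch.optim.Optimizer, num_warmup_steps: int,
                         num_training_steps: int, num_cycles: int = 1, power: float = 1.0):
    # look up the schedule constructor registered for this scheduler name
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps,
                             num_training_steps=num_training_steps, num_cycles=num_cycles)
    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps,
                             num_training_steps=num_training_steps, power=power)
    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps,
                         num_training_steps=num_training_steps)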
210
  batch_index: int
211
 
212
 
213
+ class AugHelper:
214
+ def __init__(self):
215
+ # prepare all possible augmentators
216
+ color_aug_method = albu.OneOf([
217
+ albu.HueSaturationValue(8, 0, 0, p=.5),
218
+ albu.RandomGamma((95, 105), p=.5),
219
+ ], p=.33)
220
+ flip_aug_method = albu.HorizontalFlip(p=0.5)
221
+
222
+ # key: (use_color_aug, use_flip_aug)
223
+ self.augmentors = {
224
+ (True, True): albu.Compose([
225
+ color_aug_method,
226
+ flip_aug_method,
227
+ ], p=1.),
228
+ (True, False): albu.Compose([
229
+ color_aug_method,
230
+ ], p=1.),
231
+ (False, True): albu.Compose([
232
+ flip_aug_method,
233
+ ], p=1.),
234
+ (False, False): None
235
+ }
236
+
237
+ def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
238
+ return self.augmentors[(use_color_aug, use_flip_aug)]
239
+
240
+
241
+ class BaseSubset:
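AugHelper replaces the self.aug pipeline the old constructor built per dataset: every (color_aug, flip_aug) combination is composed once up front and looked up per subset at __getitem__ time. A small usage sketch; the dummy image is illustrative only:

import numpy as np

aug_helper = AugHelper()
aug = aug_helper.get_augmentor(use_color_aug=True, use_flip_aug=False)

img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HWC uint8 image
if aug is not None:
    img = aug(image=img)["image"]  # albumentations returns a dict holding the augmented image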
242
+ def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
243
+ self.image_dir = image_dir
244
+ self.num_repeats = num_repeats
245
+ self.shuffle_caption = shuffle_caption
246
+ self.keep_tokens = keep_tokens
247
+ self.color_aug = color_aug
248
+ self.flip_aug = flip_aug
249
+ self.face_crop_aug_range = face_crop_aug_range
250
+ self.random_crop = random_crop
251
+ self.caption_dropout_rate = caption_dropout_rate
252
+ self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
253
+ self.caption_tag_dropout_rate = caption_tag_dropout_rate
254
+
255
+ self.img_count = 0
256
+
257
+
258
+ class DreamBoothSubset(BaseSubset):
259
+ def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
260
+ assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
261
+
262
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
263
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
264
+
265
+ self.is_reg = is_reg
266
+ self.class_tokens = class_tokens
267
+ self.caption_extension = caption_extension
268
+
269
+ def __eq__(self, other) -> bool:
270
+ if not isinstance(other, DreamBoothSubset):
271
+ return NotImplemented
272
+ return self.image_dir == other.image_dir
273
+
274
+
275
+ class FineTuningSubset(BaseSubset):
276
+ def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
277
+ assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
278
+
279
+ super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
280
+ face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
281
+
282
+ self.metadata_file = metadata_file
283
+
284
+ def __eq__(self, other) -> bool:
285
+ if not isinstance(other, FineTuningSubset):
286
+ return NotImplemented
287
+ return self.metadata_file == other.metadata_file
288
+
289
+
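With per-folder options moved off the dataset and onto subsets, a training script now describes each image folder (or metadata file) with one subset object. A hedged construction example; the directory names, metadata file, and option values are placeholders:

db_subset = DreamBoothSubset(
    image_dir="train_images/10_myclass", is_reg=False, class_tokens="myclass",
    caption_extension=".caption", num_repeats=10, shuffle_caption=True, keep_tokens=1,
    color_aug=False, flip_aug=True, face_crop_aug_range=None, random_crop=False,
    caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.1)

ft_subset = FineTuningSubset(
    image_dir="train_images", metadata_file="meta_lat.json", num_repeats=1,
    shuffle_caption=True, keep_tokens=0, color_aug=False, flip_aug=False,
    face_crop_aug_range=None, random_crop=False,
    caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.0)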
290
  class BaseDataset(torch.utils.data.Dataset):
291
+ def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
292
  super().__init__()
293
+ self.tokenizer = tokenizer
294
  self.max_token_length = max_token_length
 
 
295
  # width/height is used when enable_bucket==False
296
  self.width, self.height = (None, None) if resolution is None else resolution
 
 
 
297
  self.debug_dataset = debug_dataset
298
+
299
+ self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
300
+
301
  self.token_padding_disabled = False
 
 
302
  self.tag_frequency = {}
303
 
304
  self.enable_bucket = False
 
312
  self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
313
 
314
  self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
 
 
315
 
316
  # augmentation
317
+ self.aug_helper = AugHelper()
318
 
319
  self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
320
 
321
  self.image_data: Dict[str, ImageInfo] = {}
322
+ self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
323
 
324
  self.replacements = {}
325
 
326
  def set_current_epoch(self, epoch):
327
  self.current_epoch = epoch
328
+ self.shuffle_buckets()
329
 
330
  def set_tag_frequency(self, dir_name, captions):
331
  frequency_for_dir = self.tag_frequency.get(dir_name, {})
332
  self.tag_frequency[dir_name] = frequency_for_dir
333
  for caption in captions:
334
  for tag in caption.split(","):
335
+ tag = tag.strip()
336
+ if tag:
337
  tag = tag.lower()
338
  frequency = frequency_for_dir.get(tag, 0)
339
  frequency_for_dir[tag] = frequency + 1
 
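The added strip() means tags that are pure whitespace no longer end up as empty keys in the frequency table. The same counting rule, condensed with collections.Counter purely for illustration:

from collections import Counter

captions = ["1girl, solo", "1girl,  , smile"]  # note the whitespace-only tag
tags = Counter(t.strip().lower() for c in captions for t in c.split(",") if t.strip())
# Counter({'1girl': 2, 'solo': 1, 'smile': 1})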
344
  def add_replacement(self, str_from, str_to):
345
  self.replacements[str_from] = str_to
346
 
347
+ def process_caption(self, subset: BaseSubset, caption):
348
  # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
349
+ is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
350
+ is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
351
 
352
  if is_drop_out:
353
  caption = ""
354
  else:
355
+ if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
356
  def dropout_tags(tokens):
357
+ if subset.caption_tag_dropout_rate <= 0:
358
  return tokens
359
  l = []
360
  for token in tokens:
361
+ if random.random() >= subset.caption_tag_dropout_rate:
362
  l.append(token)
363
  return l
364
 
365
+ fixed_tokens = []
366
+ flex_tokens = [t.strip() for t in caption.strip().split(",")]
367
+ if subset.keep_tokens > 0:
368
+ fixed_tokens = flex_tokens[:subset.keep_tokens]
369
+ flex_tokens = flex_tokens[subset.keep_tokens:]
370
 
371
+ if subset.shuffle_caption:
372
+ random.shuffle(flex_tokens)
373
 
374
+ flex_tokens = dropout_tags(flex_tokens)
375
 
376
+ caption = ", ".join(fixed_tokens + flex_tokens)
 
377
 
378
  # textual inversion対応
379
  for str_from, str_to in self.replacements.items():
 
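process_caption now takes its settings from the subset: the first keep_tokens comma-separated tags are pinned in place, the remainder may be shuffled, and each remaining tag survives with probability 1 - caption_tag_dropout_rate. A self-contained sketch of that ordering; the standalone function name is hypothetical:

import random

def shuffle_and_drop_tags(caption: str, keep_tokens: int = 1,
                          tag_dropout_rate: float = 0.1, shuffle: bool = True) -> str:
    tokens = [t.strip() for t in caption.strip().split(",")]
    fixed, flex = tokens[:keep_tokens], tokens[keep_tokens:]
    if shuffle:
        random.shuffle(flex)
    # drop each non-fixed tag independently
    flex = [t for t in flex if random.random() >= tag_dropout_rate]
    return ", ".join(fixed + flex)

# e.g. shuffle_and_drop_tags("1girl, solo, smile, outdoors") always keeps "1girl" first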
427
  input_ids = torch.stack(iids_list) # 3,77
428
  return input_ids
429
 
430
+ def register_image(self, info: ImageInfo, subset: BaseSubset):
431
  self.image_data[info.image_key] = info
432
+ self.image_to_subset[info.image_key] = subset
433
 
434
  def make_buckets(self):
435
  '''
 
528
  img = np.array(image, np.uint8)
529
  return img
530
 
531
+ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
532
  image_height, image_width = image.shape[0:2]
533
 
534
  if image_width != resized_size[0] or image_height != resized_size[1]:
 
538
  image_height, image_width = image.shape[0:2]
539
  if image_width > reso[0]:
540
  trim_size = image_width - reso[0]
541
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
542
  # print("w", trim_size, p)
543
  image = image[:, p:p + reso[0]]
544
  if image_height > reso[1]:
545
  trim_size = image_height - reso[1]
546
+ p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
547
  # print("h", trim_size, p)
548
  image = image[p:p + reso[1]]
549
 
550
  assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
551
  return image
552
 
553
+ def is_latent_cacheable(self):
554
+ return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
555
+
556
  def cache_latents(self, vae):
557
  # TODO ここを高速化したい
558
  print("caching latents.")
559
  for info in tqdm(self.image_data.values()):
560
+ subset = self.image_to_subset[info.image_key]
561
+
562
  if info.latents_npz is not None:
563
  info.latents = self.load_latents_from_npz(info, False)
564
  info.latents = torch.FloatTensor(info.latents)
 
568
  continue
569
 
570
  image = self.load_image(info.absolute_path)
571
+ image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
572
 
573
  img_tensor = self.image_transforms(image)
574
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
575
  info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
576
 
577
+ if subset.flip_aug:
578
  image = image[:, ::-1].copy() # cannot convert to Tensor without copy
579
  img_tensor = self.image_transforms(image)
580
  img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
 
584
  image = Image.open(image_path)
585
  return image.size
586
 
587
+ def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
588
  img = self.load_image(image_path)
589
 
590
  face_cx = face_cy = face_w = face_h = 0
591
+ if subset.face_crop_aug_range is not None:
592
  tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
593
  if len(tokens) >= 5:
594
  face_cx = int(tokens[-4])
 
599
  return img, face_cx, face_cy, face_w, face_h
600
 
601
  # いい感じに切り出す
602
+ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
603
  height, width = image.shape[0:2]
604
  if height == self.height and width == self.width:
605
  return image
 
607
  # 画像サイズはsizeより大きいのでリサイズする
608
  face_size = max(face_w, face_h)
609
  min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
610
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
611
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
612
  if min_scale >= max_scale: # range指定がmin==max
613
  scale = min_scale
614
  else:
 
626
  for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
627
  p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
628
 
629
+ if subset.random_crop:
630
  # 背景も含めるために顔を中心に置く確率を高めつつずらす
631
  range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
632
  p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
633
  else:
634
  # range指定があるときのみ、すこしだけランダムに(わりと適当)
635
+ if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
636
  if face_size > self.size // 10 and face_size >= 40:
637
  p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
638
 
 
655
  return self._length
656
 
657
  def __getitem__(self, index):
 
 
 
658
  bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
659
  bucket_batch_size = self.buckets_indices[index].bucket_batch_size
660
  image_index = self.buckets_indices[index].batch_index * bucket_batch_size
 
667
 
668
  for image_key in bucket[image_index:image_index + bucket_batch_size]:
669
  image_info = self.image_data[image_key]
670
+ subset = self.image_to_subset[image_key]
671
  loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
672
 
673
  # image/latentsを処理する
674
  if image_info.latents is not None:
675
+ latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
676
  image = None
677
  elif image_info.latents_npz is not None:
678
+ latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
679
  latents = torch.FloatTensor(latents)
680
  image = None
681
  else:
682
  # 画像を読み込み、必要ならcropする
683
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
684
  im_h, im_w = img.shape[0:2]
685
 
686
  if self.enable_bucket:
687
+ img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
688
  else:
689
  if face_cx > 0: # 顔位置情報あり
690
+ img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
691
  elif im_h > self.height or im_w > self.width:
692
+ assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
693
  if im_h > self.height:
694
  p = random.randint(0, im_h - self.height)
695
  img = img[p:p + self.height]
 
701
  assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
702
 
703
  # augmentation
704
+ aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
705
+ if aug is not None:
706
+ img = aug(image=img)['image']
707
 
708
  latents = None
709
  image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
 
711
  images.append(image)
712
  latents_list.append(latents)
713
 
714
+ caption = self.process_caption(subset, image_info.caption)
715
  captions.append(caption)
716
  if not self.token_padding_disabled: # this option might be omitted in future
717
  input_ids_list.append(self.get_input_ids(caption))
 
742
 
743
 
744
  class DreamBoothDataset(BaseDataset):
745
+ def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
746
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
747
 
748
  assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
749
 
 
766
  self.bucket_reso_steps = None # この情報は使われない
767
  self.bucket_no_upscale = False
768
 
769
+ def read_caption(img_path, caption_extension):
770
  # captionの候補ファイル名を作る
771
  base_name = os.path.splitext(img_path)[0]
772
  base_name_face_det = base_name
 
789
  break
790
  return caption
791
 
792
+ def load_dreambooth_dir(subset: DreamBoothSubset):
793
+ if not os.path.isdir(subset.image_dir):
794
+ print(f"not directory: {subset.image_dir}")
795
+ return [], []
796
 
797
+ img_paths = glob_images(subset.image_dir, "*")
798
+ print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
799
 
800
  # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
801
  captions = []
802
  for img_path in img_paths:
803
+ cap_for_img = read_caption(img_path, subset.caption_extension)
804
+ if cap_for_img is None and subset.class_tokens is None:
805
+ print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
806
+ captions.append("")
807
+ else:
808
+ captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
809
 
810
+ self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
811
 
812
+ return img_paths, captions
813
 
814
+ print("prepare images.")
 
815
  num_train_images = 0
816
+ num_reg_images = 0
817
+ reg_infos: List[ImageInfo] = []
818
+ for subset in subsets:
819
+ if subset.num_repeats < 1:
820
+ print(
821
+ f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
822
+ continue
823
+
824
+ if subset in self.subsets:
825
+ print(
826
+ f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
827
+ continue
828
+
829
+ img_paths, captions = load_dreambooth_dir(subset)
830
+ if len(img_paths) < 1:
831
+ print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
832
+ continue
833
+
834
+ if subset.is_reg:
835
+ num_reg_images += subset.num_repeats * len(img_paths)
836
+ else:
837
+ num_train_images += subset.num_repeats * len(img_paths)
838
 
839
  for img_path, caption in zip(img_paths, captions):
840
+ info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
841
+ if subset.is_reg:
842
+ reg_infos.append(info)
843
+ else:
844
+ self.register_image(info, subset)
845
 
846
+ subset.img_count = len(img_paths)
847
+ self.subsets.append(subset)
848
 
849
  print(f"{num_train_images} train images with repeating.")
850
  self.num_train_images = num_train_images
851
 
852
+ print(f"{num_reg_images} reg images.")
853
+ if num_train_images < num_reg_images:
854
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
 
 
855
 
856
+ if num_reg_images == 0:
857
+ print("no regularization images / 正則化画像が見つかりませんでした")
858
+ else:
859
+ # num_repeatsを計算する:どうせ大した数ではないのでループで処理する
860
+ n = 0
861
+ first_loop = True
862
+ while n < num_train_images:
863
+ for info in reg_infos:
864
+ if first_loop:
865
+ self.register_image(info, subset)
866
+ n += info.num_repeats
867
+ else:
868
+ info.num_repeats += 1
869
+ n += 1
870
+ if n >= num_train_images:
871
+ break
872
+ first_loop = False
873
 
874
+ self.num_reg_images = num_reg_images
 
 
875
 
 
876
 
877
+ class FineTuningDataset(BaseDataset):
878
+ def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
879
+ super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
880
+
881
+ self.batch_size = batch_size
882
+
883
+ self.num_train_images = 0
884
+ self.num_reg_images = 0
885
 
886
+ for subset in subsets:
887
+ if subset.num_repeats < 1:
888
+ print(
889
+ f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
890
+ continue
891
+
892
+ if subset in self.subsets:
893
+ print(
894
+ f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
895
+ continue
896
+
897
+ # メタデータを読み込む
898
+ if os.path.exists(subset.metadata_file):
899
+ print(f"loading existing metadata: {subset.metadata_file}")
900
+ with open(subset.metadata_file, "rt", encoding='utf-8') as f:
901
+ metadata = json.load(f)
902
  else:
903
+ raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
904
 
905
+ if len(metadata) < 1:
906
+ print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
907
+ continue
908
 
909
+ tags_list = []
910
+ for image_key, img_md in metadata.items():
911
+ # path情報を作る
912
+ if os.path.exists(image_key):
913
+ abs_path = image_key
914
+ else:
915
+ npz_path = os.path.join(subset.image_dir, image_key + ".npz")
916
+ if os.path.exists(npz_path):
917
+ abs_path = npz_path
918
+ else:
919
+ # わりといい加減だがいい方法が思いつかん
920
+ abs_path = glob_images(subset.image_dir, image_key)
921
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
922
+ abs_path = abs_path[0]
923
 
924
+ caption = img_md.get('caption')
925
+ tags = img_md.get('tags')
926
+ if caption is None:
927
+ caption = tags
928
+ elif tags is not None and len(tags) > 0:
929
+ caption = caption + ', ' + tags
930
+ tags_list.append(tags)
931
 
932
+ if caption is None:
933
+ caption = ""
 
934
 
935
+ image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
936
+ image_info.image_size = img_md.get('train_resolution')
937
+
938
+ if not subset.color_aug and not subset.random_crop:
939
+ # if npz exists, use them
940
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
941
+
942
+ self.register_image(image_info, subset)
943
 
944
+ self.num_train_images += len(metadata) * subset.num_repeats
945
+
946
+ # TODO do not record tag freq when no tag
947
+ self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
948
+ subset.img_count = len(metadata)
949
+ self.subsets.append(subset)
950
 
951
  # check existence of all npz files
952
+ use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
953
  if use_npz_latents:
954
+ flip_aug_in_subset = False
955
  npz_any = False
956
  npz_all = True
957
+
958
  for image_info in self.image_data.values():
959
+ subset = self.image_to_subset[image_info.image_key]
960
+
961
  has_npz = image_info.latents_npz is not None
962
  npz_any = npz_any or has_npz
963
 
964
+ if subset.flip_aug:
965
  has_npz = has_npz and image_info.latents_npz_flipped is not None
966
+ flip_aug_in_subset = True
967
  npz_all = npz_all and has_npz
968
 
969
  if npz_any and not npz_all:
 
975
  elif not npz_all:
976
  use_npz_latents = False
977
  print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
978
+ if flip_aug_in_subset:
979
  print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
980
  # else:
981
  # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
 
1021
  for image_info in self.image_data.values():
1022
  image_info.latents_npz = image_info.latents_npz_flipped = None
1023
 
1024
+ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
1025
  base_name = os.path.splitext(image_key)[0]
1026
  npz_file_norm = base_name + '.npz'
1027
 
 
1033
  return npz_file_norm, npz_file_flip
1034
 
1035
  # image_key is relative path
1036
+ npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
1037
+ npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
1038
 
1039
  if not os.path.exists(npz_file_norm):
1040
  npz_file_norm = None
 
1045
  return npz_file_norm, npz_file_flip
1046
 
1047
 
1048
+ # behave as Dataset mock
1049
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1050
+ def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
1051
+ self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
1052
+
1053
+ super().__init__(datasets)
1054
+
1055
+ self.image_data = {}
1056
+ self.num_train_images = 0
1057
+ self.num_reg_images = 0
1058
+
1059
+ # simply concat together
1060
+ # TODO: handling image_data key duplication among dataset
1061
+ # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
1062
+ for dataset in datasets:
1063
+ self.image_data.update(dataset.image_data)
1064
+ self.num_train_images += dataset.num_train_images
1065
+ self.num_reg_images += dataset.num_reg_images
1066
+
1067
+ def add_replacement(self, str_from, str_to):
1068
+ for dataset in self.datasets:
1069
+ dataset.add_replacement(str_from, str_to)
1070
+
1071
+ # def make_buckets(self):
1072
+ # for dataset in self.datasets:
1073
+ # dataset.make_buckets()
1074
+
1075
+ def cache_latents(self, vae):
1076
+ for i, dataset in enumerate(self.datasets):
1077
+ print(f"[Dataset {i}]")
1078
+ dataset.cache_latents(vae)
1079
+
1080
+ def is_latent_cacheable(self) -> bool:
1081
+ return all([dataset.is_latent_cacheable() for dataset in self.datasets])
1082
+
1083
+ def set_current_epoch(self, epoch):
1084
+ for dataset in self.datasets:
1085
+ dataset.set_current_epoch(epoch)
1086
+
1087
+ def disable_token_padding(self):
1088
+ for dataset in self.datasets:
1089
+ dataset.disable_token_padding()
1090
+
1091
+
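DatasetGroup lets the training scripts treat several datasets (for example a DreamBooth-style folder set plus a fine-tuning metadata set) as one Dataset, forwarding calls such as cache_latents and set_current_epoch to every member. A hedged usage sketch reusing the hypothetical db_subset from the earlier example; tokenizer is assumed to come from load_tokenizer, and the DataLoader settings are illustrative:

import torch

dreambooth_ds = DreamBoothDataset([db_subset], batch_size=1, tokenizer=tokenizer,
                                  max_token_length=None, resolution=(512, 512),
                                  enable_bucket=False, min_bucket_reso=256, max_bucket_reso=1024,
                                  bucket_reso_steps=64, bucket_no_upscale=False,
                                  prior_loss_weight=1.0, debug_dataset=False)
dreambooth_ds.make_buckets()

dataset_group = DatasetGroup([dreambooth_ds])
loader = torch.utils.data.DataLoader(dataset_group, batch_size=1, shuffle=True,
                                     collate_fn=lambda examples: examples[0], num_workers=2)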
1092
  def debug_dataset(train_dataset, show_input_ids=False):
1093
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
1094
  print("Escape for exit. / Escキーで中断、終了します")
1095
 
1096
  train_dataset.set_current_epoch(1)
1097
  k = 0
1098
+ indices = list(range(len(train_dataset)))
1099
+ random.shuffle(indices)
1100
+ for i, idx in enumerate(indices):
1101
+ example = train_dataset[idx]
1102
  if example['latents'] is not None:
1103
  print(f"sample has latents from npz file: {example['latents'].size()}")
1104
  for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
 
1503
  help='enable v-parameterization training / v-parameterization学習を有効にする')
1504
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1505
  help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1506
+ parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
1507
+ help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
1508
+
1509
+
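--tokenizer_cache_dir supports offline training by keeping a local copy of the tokenizer. The matching load_tokenizer changes are collapsed in this diff view, so the following is only a sketch of the likely pattern; the helper name and cache layout are assumptions:

import os
from transformers import CLIPTokenizer

TOKENIZER_PATH = "openai/clip-vit-large-patch14"  # the v1 tokenizer id used elsewhere in this file

def load_tokenizer_with_cache(tokenizer_cache_dir=None) -> CLIPTokenizer:
    # prefer a previously saved local copy; otherwise download once and save it for offline use
    if tokenizer_cache_dir:
        local_path = os.path.join(tokenizer_cache_dir, TOKENIZER_PATH.replace("/", "_"))
        if os.path.exists(local_path):
            return CLIPTokenizer.from_pretrained(local_path)
        tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
        tokenizer.save_pretrained(local_path)
        return tokenizer
    return CLIPTokenizer.from_pretrained(TOKENIZER_PATH)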
1510
+ def add_optimizer_arguments(parser: argparse.ArgumentParser):
1511
+ parser.add_argument("--optimizer_type", type=str, default="",
1512
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
1513
+
1514
+ # backward compatibility
1515
+ parser.add_argument("--use_8bit_adam", action="store_true",
1516
+ help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1517
+ parser.add_argument("--use_lion_optimizer", action="store_true",
1518
+ help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
1519
+
1520
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1521
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
1522
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
1523
+
1524
+ parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
1525
+ help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
1526
+
1527
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1528
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
1529
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1530
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1531
+ parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
1532
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
1533
+ parser.add_argument("--lr_scheduler_power", type=float, default=1,
1534
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
1535
 
1536
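The single-purpose --use_8bit_adam / --use_lion_optimizer switches are kept only for backward compatibility; --optimizer_type plus the free-form --optimizer_args now select and configure the optimizer. A hedged sketch of the kind of dispatcher this implies; the function name, the argument parsing, and the subset of optimizers shown are assumptions:

import ast
import torch

def get_optimizer_sketch(args, trainable_params) -> torch.optim.Optimizer:
    name = (args.optimizer_type or "AdamW").lower()
    if args.use_8bit_adam:
        name = "adamw8bit"   # legacy flag
    elif args.use_lion_optimizer:
        name = "lion"        # legacy flag

    # parse --optimizer_args entries of the form "key=value"
    extra = {}
    for kv in (args.optimizer_args or []):
        key, value = kv.split("=", 1)
        try:
            extra[key] = ast.literal_eval(value)  # "betas=0.9,0.999" -> (0.9, 0.999), "0.01" -> 0.01
        except (ValueError, SyntaxError):
            extra[key] = value

    if name == "adamw8bit":
        import bitsandbytes as bnb
        return bnb.optim.AdamW8bit(trainable_params, lr=args.learning_rate, **extra)
    if name == "lion":
        import lion_pytorch
        return lion_pytorch.Lion(trainable_params, lr=args.learning_rate, **extra)
    if name == "sgdnesterov":
        return torch.optim.SGD(trainable_params, lr=args.learning_rate,
                               momentum=extra.pop("momentum", 0.9), nesterov=True, **extra)
    return torch.optim.AdamW(trainable_params, lr=args.learning_rate, **extra)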
 
1537
  def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
 
1555
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1556
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1557
  help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1558
  parser.add_argument("--mem_eff_attn", action="store_true",
1559
  help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1560
  parser.add_argument("--xformers", action="store_true",
 
1562
  parser.add_argument("--vae", type=str, default=None,
1563
  help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1564
 
 
1565
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1566
  parser.add_argument("--max_train_epochs", type=int, default=None,
1567
  help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
 
1582
  parser.add_argument("--logging_dir", type=str, default=None,
1583
  help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1584
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
 
 
 
 
1585
  parser.add_argument("--noise_offset", type=float, default=None,
1586
  help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
1587
  parser.add_argument("--lowram", action="store_true",
1588
  help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
1589
 
1590
+ parser.add_argument("--sample_every_n_steps", type=int, default=None,
1591
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
1592
+ parser.add_argument("--sample_every_n_epochs", type=int, default=None,
1593
+ help="generate sample images every N epochs (overrides sample_every_n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
1594
+ parser.add_argument("--sample_prompts", type=str, default=None,
1595
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
1596
+ parser.add_argument('--sample_sampler', type=str, default='ddim',
1597
+ choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
1598
+ 'dpmsolver++', 'dpmsingle',
1599
+ 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
1600
+ help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
1601
+
1602
  if support_dreambooth:
1603
  # DreamBooth training
1604
  parser.add_argument("--prior_loss_weight", type=float, default=1.0,
 
1620
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1621
  parser.add_argument("--caption_extention", type=str, default=None,
1622
  help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1623
+ parser.add_argument("--keep_tokens", type=int, default=0,
1624
+ help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
1625
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1626
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1627
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
 
1646
  if support_caption_dropout:
1647
  # Textual Inversion はcaptionのdropoutをsupportしない
1648
  # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
1649
+ parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
1650
  help="Rate of dropping out captions (0.0~1.0) / captionをdropoutする割合")
1651
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
1652
  help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1653
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
1654
  help="Rate of dropping out comma-separated tokens (0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1655
 
1656
  if support_dreambooth:
 
1675
  # region utils
1676
 
1677
 
1678
+ def get_optimizer(args, trainable_params):
1679
+ # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
1680
+
1681
+ optimizer_type = args.optimizer_type
1682
+ if args.use_8bit_adam:
1683
+ assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
1684
+ assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
1685
+ optimizer_type = "AdamW8bit"
1686
+
1687
+ elif args.use_lion_optimizer:
1688
+ assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
1689
+ optimizer_type = "Lion"
1690
+
1691
+ if optimizer_type is None or optimizer_type == "":
1692
+ optimizer_type = "AdamW"
1693
+ optimizer_type = optimizer_type.lower()
1694
+
1695
+ # 引数を分解する:boolとfloat、tupleのみ対応
1696
+ optimizer_kwargs = {}
1697
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
1698
+ for arg in args.optimizer_args:
1699
+ key, value = arg.split('=')
1700
+
1701
+ value = value.split(",")
1702
+ for i in range(len(value)):
1703
+ if value[i].lower() == "true" or value[i].lower() == "false":
1704
+ value[i] = (value[i].lower() == "true")
1705
+ else:
1706
+ value[i] = float(value[i])
1707
+ if len(value) == 1:
1708
+ value = value[0]
1709
+ else:
1710
+ value = tuple(value)
1711
+
1712
+ optimizer_kwargs[key] = value
1713
+ # print("optkwargs:", optimizer_kwargs)
1714
+
1715
+ lr = args.learning_rate
1716
+
1717
+ if optimizer_type == "AdamW8bit".lower():
1718
+ try:
1719
+ import bitsandbytes as bnb
1720
+ except ImportError:
1721
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
1722
+ print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
1723
+ optimizer_class = bnb.optim.AdamW8bit
1724
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1725
+
1726
+ elif optimizer_type == "SGDNesterov8bit".lower():
1727
+ try:
1728
+ import bitsandbytes as bnb
1729
+ except ImportError:
1730
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
1731
+ print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
1732
+ if "momentum" not in optimizer_kwargs:
1733
+ print(f"8-bit SGD with Nesterov requires momentum; setting momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1734
+ optimizer_kwargs["momentum"] = 0.9
1735
+
1736
+ optimizer_class = bnb.optim.SGD8bit
1737
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1738
+
1739
+ elif optimizer_type == "Lion".lower():
1740
+ try:
1741
+ import lion_pytorch
1742
+ except ImportError:
1743
+ raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
1744
+ print(f"use Lion optimizer | {optimizer_kwargs}")
1745
+ optimizer_class = lion_pytorch.Lion
1746
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1747
+
1748
+ elif optimizer_type == "SGDNesterov".lower():
1749
+ print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
1750
+ if "momentum" not in optimizer_kwargs:
1751
+ print(f"SGD with Nesterov requires momentum; setting momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
1752
+ optimizer_kwargs["momentum"] = 0.9
1753
+
1754
+ optimizer_class = torch.optim.SGD
1755
+ optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
1756
+
1757
+ elif optimizer_type == "DAdaptation".lower():
1758
+ try:
1759
+ import dadaptation
1760
+ except ImportError:
1761
+ raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
1762
+ print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
1763
+
1764
+ actual_lr = lr
1765
+ lr_count = 1
1766
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1767
+ lrs = set()
1768
+ actual_lr = trainable_params[0].get("lr", actual_lr)
1769
+ for group in trainable_params:
1770
+ lrs.add(group.get("lr", actual_lr))
1771
+ lr_count = len(lrs)
1772
+
1773
+ if actual_lr <= 0.1:
1774
+ print(
1775
+ f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
1776
+ print('recommend option: lr=1.0 / 推奨は1.0です')
1777
+ if lr_count > 1:
1778
+ print(
1779
+ f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
1780
+
1781
+ optimizer_class = dadaptation.DAdaptAdam
1782
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1783
+
1784
+ elif optimizer_type == "Adafactor".lower():
1785
+ # 引数を確認して適宜補正する
1786
+ if "relative_step" not in optimizer_kwargs:
1787
+ optimizer_kwargs["relative_step"] = True # default
1788
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
1789
+ print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
1790
+ optimizer_kwargs["relative_step"] = True
1791
+ print(f"use Adafactor optimizer | {optimizer_kwargs}")
1792
+
1793
+ if optimizer_kwargs["relative_step"]:
1794
+ print(f"relative_step is true / relative_stepがtrueです")
1795
+ if lr != 0.0:
1796
+ print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
1797
+ args.learning_rate = None
1798
+
1799
+ # trainable_paramsがgroupだった時の処理:lrを削除する
1800
+ if type(trainable_params) == list and type(trainable_params[0]) == dict:
1801
+ has_group_lr = False
1802
+ for group in trainable_params:
1803
+ p = group.pop("lr", None)
1804
+ has_group_lr = has_group_lr or (p is not None)
1805
+
1806
+ if has_group_lr:
1807
+ # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
1808
+ print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
1809
+ args.unet_lr = None
1810
+ args.text_encoder_lr = None
1811
+
1812
+ if args.lr_scheduler != "adafactor":
1813
+ print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
1814
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
1815
+
1816
+ lr = None
1817
+ else:
1818
+ if args.max_grad_norm != 0.0:
1819
+ print(f"because max_grad_norm is set, clip_grad_norm is enabled; consider setting it to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
1820
+ if args.lr_scheduler != "constant_with_warmup":
1821
+ print(f"constant_with_warmup scheduler may be a better choice / スケジューラはconstant_with_warmupが良いかもしれません")
1822
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
1823
+ print(f"clip_threshold=1.0 may be a better choice / clip_thresholdは1.0が良いかもしれません")
1824
+
1825
+ optimizer_class = transformers.optimization.Adafactor
1826
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1827
+
1828
+ elif optimizer_type == "AdamW".lower():
1829
+ print(f"use AdamW optimizer | {optimizer_kwargs}")
1830
+ optimizer_class = torch.optim.AdamW
1831
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1832
+
1833
+ else:
1834
+ # 任意のoptimizerを使う
1835
+ optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
1836
+ print(f"use {optimizer_type} | {optimizer_kwargs}")
1837
+ if "." not in optimizer_type:
1838
+ optimizer_module = torch.optim
1839
+ else:
1840
+ values = optimizer_type.split(".")
1841
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
1842
+ optimizer_type = values[-1]
1843
+
1844
+ optimizer_class = getattr(optimizer_module, optimizer_type)
1845
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
1846
+
1847
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
1848
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
1849
+
1850
+ return optimizer_name, optimizer_args, optimizer
1851
+
1852
+
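A minimal usage sketch of how the options registered by add_optimizer_arguments feed get_optimizer; the import path, the throwaway Linear model and all argument values are illustrative assumptions, not part of the scripts themselves:

import argparse
import torch
from library.train_util import add_optimizer_arguments, get_optimizer  # assumed import path

parser = argparse.ArgumentParser()
add_optimizer_arguments(parser)
args = parser.parse_args(["--optimizer_type", "AdamW",
                          "--learning_rate", "1e-4",
                          "--optimizer_args", "weight_decay=0.01", "betas=0.9,0.999"])

model = torch.nn.Linear(8, 8)  # stand-in for the real U-Net / Text Encoder parameters
name, parsed_args, optimizer = get_optimizer(args, model.parameters())
# --optimizer_args is parsed into {"weight_decay": 0.01, "betas": (0.9, 0.999)} and forwarded
# as keyword arguments, so this builds torch.optim.AdamW(lr=1e-4, weight_decay=0.01, betas=(0.9, 0.999)).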
1853
+ # Monkeypatch newer get_scheduler() function overriding the current version of diffusers.optimization.get_scheduler
1854
+ # code is taken from https://github.com/huggingface/diffusers (diffusers.optimization), commit d87cc15977b87160c30abaace3894e802ad9e1e6
1855
+ # Which is a newer release of diffusers than currently packaged with sd-scripts
1856
+ # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
1857
+
1858
+
1859
+ def get_scheduler_fix(
1860
+ name: Union[str, SchedulerType],
1861
+ optimizer: Optimizer,
1862
+ num_warmup_steps: Optional[int] = None,
1863
+ num_training_steps: Optional[int] = None,
1864
+ num_cycles: int = 1,
1865
+ power: float = 1.0,
1866
+ ):
1867
+ """
1868
+ Unified API to get any scheduler from its name.
1869
+ Args:
1870
+ name (`str` or `SchedulerType`):
1871
+ The name of the scheduler to use.
1872
+ optimizer (`torch.optim.Optimizer`):
1873
+ The optimizer that will be used during training.
1874
+ num_warmup_steps (`int`, *optional*):
1875
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
1876
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1877
+ num_training_steps (`int`, *optional*):
1878
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
1879
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
1880
+ num_cycles (`int`, *optional*):
1881
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
1882
+ power (`float`, *optional*, defaults to 1.0):
1883
+ Power factor. See `POLYNOMIAL` scheduler
1884
+ last_epoch (`int`, *optional*, defaults to -1):
1885
+ The index of the last epoch when resuming training.
1886
+ """
1887
+ if name.startswith("adafactor"):
1888
+ assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
1889
+ initial_lr = float(name.split(':')[1])
1890
+ # print("adafactor scheduler init lr", initial_lr)
1891
+ return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
1892
+
1893
+ name = SchedulerType(name)
1894
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
1895
+ if name == SchedulerType.CONSTANT:
1896
+ return schedule_func(optimizer)
1897
+
1898
+ # All other schedulers require `num_warmup_steps`
1899
+ if num_warmup_steps is None:
1900
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
1901
+
1902
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
1903
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
1904
+
1905
+ # All other schedulers require `num_training_steps`
1906
+ if num_training_steps is None:
1907
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1908
+
1909
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
1910
+ return schedule_func(
1911
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
1912
+ )
1913
+
1914
+ if name == SchedulerType.POLYNOMIAL:
1915
+ return schedule_func(
1916
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
1917
+ )
1918
+
1919
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
1920
+
1921
+
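A minimal sketch of calling get_scheduler_fix directly, assuming it is importable from library.train_util; the optimizer and the step counts stand in for whatever the training loop prepares, and the values mirror the options added in add_optimizer_arguments:

import torch
from library.train_util import get_scheduler_fix  # assumed import path

optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=1e-4)
# e.g. --lr_scheduler cosine_with_restarts --lr_warmup_steps 100 --lr_scheduler_num_cycles 3
lr_scheduler = get_scheduler_fix("cosine_with_restarts", optimizer,
                                 num_warmup_steps=100, num_training_steps=1600,
                                 num_cycles=3, power=1.0)
optimizer.step()
lr_scheduler.step()  # advanced once per training step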
1922
  def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1923
  # backward compatibility
1924
  if args.caption_extention is not None:
1925
  args.caption_extension = args.caption_extention
1926
  args.caption_extention = None
1927
 
 
 
 
 
1928
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1929
  if args.resolution is not None:
1930
  args.resolution = tuple([int(r) for r in args.resolution.split(',')])
 
1947
 
1948
  def load_tokenizer(args: argparse.Namespace):
1949
  print("prepare tokenizer")
1950
+ original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
1951
+
1952
+ tokenizer: CLIPTokenizer = None
1953
+ if args.tokenizer_cache_dir:
1954
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
1955
+ if os.path.exists(local_tokenizer_path):
1956
+ print(f"load tokenizer from cache: {local_tokenizer_path}")
1957
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
1958
+
1959
+ if tokenizer is None:
1960
+ if args.v2:
1961
+ tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
1962
+ else:
1963
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
1964
+
1965
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
1966
  print(f"update token length: {args.max_token_length}")
1967
+
1968
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
1969
+ print(f"save Tokenizer to cache: {local_tokenizer_path}")
1970
+ tokenizer.save_pretrained(local_tokenizer_path)
1971
+
1972
  return tokenizer
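The cache folder name is the Hub path with '/' replaced by '_', so the first run (with network access) saves the tokenizer and later runs can train fully offline. For example, assuming the SD1.x default TOKENIZER_PATH of "openai/clip-vit-large-patch14" defined elsewhere in this file, --tokenizer_cache_dir ./tok_cache would resolve to:

./tok_cache/openai_clip-vit-large-patch14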
1973
 
1974
 
 
2019
 
2020
 
2021
  def load_target_model(args: argparse.Namespace, weight_dtype):
2022
+ name_or_path = args.pretrained_model_name_or_path
2023
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
2024
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
2025
  if load_stable_diffusion_format:
2026
  print("load StableDiffusion checkpoint")
2027
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
2028
  else:
2029
  print("load Diffusers pretrained models")
2030
+ try:
2031
+ pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
2032
+ except EnvironmentError as ex:
2033
+ print(
2034
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
2035
  text_encoder = pipe.text_encoder
2036
  vae = pipe.vae
2037
  unet = pipe.unet
 
2200
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
2201
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
2202
 
2203
+
2204
+ # scheduler:
2205
+ SCHEDULER_LINEAR_START = 0.00085
2206
+ SCHEDULER_LINEAR_END = 0.0120
2207
+ SCHEDULER_TIMESTEPS = 1000
2208
+ SCHEDLER_SCHEDULE = 'scaled_linear'
2209
+
2210
+
2211
+ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
2212
+ """
2213
+ 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
2214
+ clip skipは対応した
2215
+ """
2216
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
2217
+ return
2218
+ if args.sample_every_n_epochs is not None:
2219
+ # sample_every_n_steps は無視する
2220
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
2221
+ return
2222
+ else:
2223
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
2224
+ return
2225
+
2226
+ print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
2227
+ if not os.path.isfile(args.sample_prompts):
2228
+ print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
2229
+ return
2230
+
2231
+ org_vae_device = vae.device # CPUにいるはず
2232
+ vae.to(device)
2233
+
2234
+ # clip skip 対応のための wrapper を作る
2235
+ if args.clip_skip is None:
2236
+ text_encoder_or_wrapper = text_encoder
2237
+ else:
2238
+ class Wrapper():
2239
+ def __init__(self, tenc) -> None:
2240
+ self.tenc = tenc
2241
+ self.config = {}
2242
+ super().__init__()
2243
+
2244
+ def __call__(self, input_ids, attention_mask):
2245
+ enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
2246
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
2247
+ encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
2248
+ pooled_output = enc_out['pooler_output']
2249
+ return encoder_hidden_states, pooled_output # 1st output is only used
2250
+
2251
+ text_encoder_or_wrapper = Wrapper(text_encoder)
2252
+
2253
+ # read prompts
2254
+ with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
2255
+ prompts = f.readlines()
2256
+
2257
+ # schedulerを用意する
2258
+ sched_init_args = {}
2259
+ if args.sample_sampler == "ddim":
2260
+ scheduler_cls = DDIMScheduler
2261
+ elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
2262
+ scheduler_cls = DDPMScheduler
2263
+ elif args.sample_sampler == "pndm":
2264
+ scheduler_cls = PNDMScheduler
2265
+ elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
2266
+ scheduler_cls = LMSDiscreteScheduler
2267
+ elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
2268
+ scheduler_cls = EulerDiscreteScheduler
2269
+ elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
2270
+ scheduler_cls = EulerAncestralDiscreteScheduler
2271
+ elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
2272
+ scheduler_cls = DPMSolverMultistepScheduler
2273
+ sched_init_args['algorithm_type'] = args.sample_sampler
2274
+ elif args.sample_sampler == "dpmsingle":
2275
+ scheduler_cls = DPMSolverSinglestepScheduler
2276
+ elif args.sample_sampler == "heun":
2277
+ scheduler_cls = HeunDiscreteScheduler
2278
+ elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
2279
+ scheduler_cls = KDPM2DiscreteScheduler
2280
+ elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
2281
+ scheduler_cls = KDPM2AncestralDiscreteScheduler
2282
+ else:
2283
+ scheduler_cls = DDIMScheduler
2284
+
2285
+ if args.v_parameterization:
2286
+ sched_init_args['prediction_type'] = 'v_prediction'
2287
+
2288
+ scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
2289
+ beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
2290
+ beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
2291
+
2292
+ # clip_sample=Trueにする
2293
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
2294
+ # print("set clip_sample to True")
2295
+ scheduler.config.clip_sample = True
2296
+
2297
+ pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
2298
+ scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
2299
+ pipeline.to(device)
2300
+
2301
+ save_dir = args.output_dir + "/sample"
2302
+ os.makedirs(save_dir, exist_ok=True)
2303
+
2304
+ rng_state = torch.get_rng_state()
2305
+ cuda_rng_state = torch.cuda.get_rng_state()
2306
+
2307
+ with torch.no_grad():
2308
+ with accelerator.autocast():
2309
+ for i, prompt in enumerate(prompts):
2310
+ if not accelerator.is_main_process:
2311
+ continue
2312
+ prompt = prompt.strip()
2313
+ if len(prompt) == 0 or prompt[0] == '#':
2314
+ continue
2315
+
2316
+ # subset of gen_img_diffusers
2317
+ prompt_args = prompt.split(' --')
2318
+ prompt = prompt_args[0]
2319
+ negative_prompt = None
2320
+ sample_steps = 30
2321
+ width = height = 512
2322
+ scale = 7.5
2323
+ seed = None
2324
+ for parg in prompt_args:
2325
+ try:
2326
+ m = re.match(r'w (\d+)', parg, re.IGNORECASE)
2327
+ if m:
2328
+ width = int(m.group(1))
2329
+ continue
2330
+
2331
+ m = re.match(r'h (\d+)', parg, re.IGNORECASE)
2332
+ if m:
2333
+ height = int(m.group(1))
2334
+ continue
2335
+
2336
+ m = re.match(r'd (\d+)', parg, re.IGNORECASE)
2337
+ if m:
2338
+ seed = int(m.group(1))
2339
+ continue
2340
+
2341
+ m = re.match(r's (\d+)', parg, re.IGNORECASE)
2342
+ if m: # steps
2343
+ sample_steps = max(1, min(1000, int(m.group(1))))
2344
+ continue
2345
+
2346
+ m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
2347
+ if m: # scale
2348
+ scale = float(m.group(1))
2349
+ continue
2350
+
2351
+ m = re.match(r'n (.+)', parg, re.IGNORECASE)
2352
+ if m: # negative prompt
2353
+ negative_prompt = m.group(1)
2354
+ continue
2355
+
2356
+ except ValueError as ex:
2357
+ print(f"Exception in parsing / 解析エラー: {parg}")
2358
+ print(ex)
2359
+
2360
+ if seed is not None:
2361
+ torch.manual_seed(seed)
2362
+ torch.cuda.manual_seed(seed)
2363
+
2364
+ if prompt_replacement is not None:
2365
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
2366
+ if negative_prompt is not None:
2367
+ negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
2368
+
2369
+ height = max(64, height - height % 8) # round to divisible by 8
2370
+ width = max(64, width - width % 8) # round to divisible by 8
2371
+ print(f"prompt: {prompt}")
2372
+ print(f"negative_prompt: {negative_prompt}")
2373
+ print(f"height: {height}")
2374
+ print(f"width: {width}")
2375
+ print(f"sample_steps: {sample_steps}")
2376
+ print(f"scale: {scale}")
2377
+ image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
2378
+
2379
+ ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
2380
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
2381
+ seed_suffix = "" if seed is None else f"_{seed}"
2382
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
2383
+
2384
+ image.save(os.path.join(save_dir, img_filename))
2385
+
2386
+ # clear pipeline and cache to reduce vram usage
2387
+ del pipeline
2388
+ torch.cuda.empty_cache()
2389
+
2390
+ torch.set_rng_state(rng_state)
2391
+ torch.cuda.set_rng_state(cuda_rng_state)
2392
+ vae.to(org_vae_device)
2393
+
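Based on the ' --' splitting and the regexes above, each non-comment line of the --sample_prompts file is a prompt followed by optional switches: --n negative prompt, --w width, --h height, --d seed, --s steps (clamped to 1-1000), --l CFG scale; width and height are then rounded down to multiples of 8. An illustrative prompt file (values made up):

# lines starting with # are skipped
masterpiece, best quality, 1girl in a garden --n lowres, bad anatomy --w 512 --h 768 --d 1 --s 28 --l 7.5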
2394
  # endregion
2395
 
2396
  # region 前処理用
networks/check_lora_weights.py CHANGED
@@ -21,7 +21,7 @@ def main(file):
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
- print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
 
21
 
22
  for key, value in values:
23
  value = value.to(torch.float32)
24
+ print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
25
 
26
 
27
  if __name__ == '__main__':
networks/extract_lora_from_models.py CHANGED
@@ -45,8 +45,13 @@ def svd(args):
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
- lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
49
- lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
 
 
 
 
 
50
  assert len(lora_network_o.text_encoder_loras) == len(
51
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
52
 
@@ -85,13 +90,28 @@ def svd(args):
85
 
86
  # make LoRA with svd
87
  print("calculating by svd")
88
- rank = args.dim
89
  lora_weights = {}
90
  with torch.no_grad():
91
  for lora_name, mat in tqdm(list(diffs.items())):
 
92
  conv2d = (len(mat.size()) == 4)
 
 
 
 
 
 
 
 
 
 
 
 
93
  if conv2d:
94
- mat = mat.squeeze()
 
 
 
95
 
96
  U, S, Vh = torch.linalg.svd(mat)
97
 
@@ -108,30 +128,27 @@ def svd(args):
108
  U = U.clamp(low_val, hi_val)
109
  Vh = Vh.clamp(low_val, hi_val)
110
 
111
- lora_weights[lora_name] = (U, Vh)
112
-
113
- # make state dict for LoRA
114
- lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
115
- lora_sd = lora_network_o.state_dict()
116
- print(f"LoRA has {len(lora_sd)} weights.")
117
-
118
- for key in list(lora_sd.keys()):
119
- if "alpha" in key:
120
- continue
121
 
122
- lora_name = key.split('.')[0]
123
- i = 0 if "lora_up" in key else 1
124
 
125
- weights = lora_weights[lora_name][i]
126
- # print(key, i, weights.size(), lora_sd[key].size())
127
- if len(lora_sd[key].size()) == 4:
128
- weights = weights.unsqueeze(2).unsqueeze(3)
129
 
130
- assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
131
- lora_sd[key] = weights
 
 
 
 
132
 
133
  # load state dict to LoRA and save it
134
- info = lora_network_o.load_state_dict(lora_sd)
 
 
 
135
  print(f"Loading extracted LoRA weights: {info}")
136
 
137
  dir_name = os.path.dirname(args.save_to)
@@ -139,9 +156,9 @@ def svd(args):
139
  os.makedirs(dir_name, exist_ok=True)
140
 
141
  # minimum metadata
142
- metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
143
 
144
- lora_network_o.save_weights(args.save_to, save_dtype, metadata)
145
  print(f"LoRA weights are saved to: {args.save_to}")
146
 
147
 
@@ -158,6 +175,8 @@ if __name__ == '__main__':
158
  parser.add_argument("--save_to", type=str, default=None,
159
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
160
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
 
45
  text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
 
47
  # create LoRA network to extract weights: Use dim (rank) as alpha
48
+ if args.conv_dim is None:
49
+ kwargs = {}
50
+ else:
51
+ kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
52
+
53
+ lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
54
+ lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
55
  assert len(lora_network_o.text_encoder_loras) == len(
56
  lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
57
 
 
90
 
91
  # make LoRA with svd
92
  print("calculating by svd")
 
93
  lora_weights = {}
94
  with torch.no_grad():
95
  for lora_name, mat in tqdm(list(diffs.items())):
96
+ # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
97
  conv2d = (len(mat.size()) == 4)
98
+ kernel_size = None if not conv2d else mat.size()[2:4]
99
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
100
+
101
+ rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
102
+ out_dim, in_dim = mat.size()[0:2]
103
+
104
+ if args.device:
105
+ mat = mat.to(args.device)
106
+
107
+ # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
108
+ rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
109
+
110
  if conv2d:
111
+ if conv2d_3x3:
112
+ mat = mat.flatten(start_dim=1)
113
+ else:
114
+ mat = mat.squeeze()
115
 
116
  U, S, Vh = torch.linalg.svd(mat)
117
 
 
128
  U = U.clamp(low_val, hi_val)
129
  Vh = Vh.clamp(low_val, hi_val)
130
 
131
+ if conv2d:
132
+ U = U.reshape(out_dim, rank, 1, 1)
133
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
 
 
 
 
 
 
 
134
 
135
+ U = U.to("cpu").contiguous()
136
+ Vh = Vh.to("cpu").contiguous()
137
 
138
+ lora_weights[lora_name] = (U, Vh)
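In short: a 3x3 conv difference of shape (out, in, kh, kw) is flattened to (out, in*kh*kw) before the SVD, so the truncated factors can be folded back into a (rank, in, kh, kw) lora_down kernel and an (out, rank, 1, 1) lora_up kernel, while 1x1 convs are squeezed and treated like linear layers. A shape walk-through with made-up sizes:

# Illustrative shapes only (rank 8, a 640-channel -> 320-channel 3x3 conv):
#   mat:  (320, 640, 3, 3)  --flatten-->  (320, 5760)
#   SVD + truncation: U -> (320, 8), Vh -> (8, 5760)
#   reshape: lora_up = U -> (320, 8, 1, 1),  lora_down = Vh -> (8, 640, 3, 3)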
 
 
 
139
 
140
+ # make state dict for LoRA
141
+ lora_sd = {}
142
+ for lora_name, (up_weight, down_weight) in lora_weights.items():
143
+ lora_sd[lora_name + '.lora_up.weight'] = up_weight
144
+ lora_sd[lora_name + '.lora_down.weight'] = down_weight
145
+ lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
146
 
147
  # load state dict to LoRA and save it
148
+ lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
149
+ lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
150
+
151
+ info = lora_network_save.load_state_dict(lora_sd)
152
  print(f"Loading extracted LoRA weights: {info}")
153
 
154
  dir_name = os.path.dirname(args.save_to)
 
156
  os.makedirs(dir_name, exist_ok=True)
157
 
158
  # minimum metadata
159
+ metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
160
 
161
+ lora_network_save.save_weights(args.save_to, save_dtype, metadata)
162
  print(f"LoRA weights are saved to: {args.save_to}")
163
 
164
 
 
175
  parser.add_argument("--save_to", type=str, default=None,
176
  help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
177
  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
178
+ parser.add_argument("--conv_dim", type=int, default=None,
179
+ help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
180
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
181
 
182
  args = parser.parse_args()
networks/lora.py CHANGED
@@ -6,6 +6,7 @@
6
  import math
7
  import os
8
  from typing import List
 
9
  import torch
10
 
11
  from library import train_util
@@ -20,22 +21,34 @@ class LoRAModule(torch.nn.Module):
20
  """ if alpha == 0 or None, alpha is rank (no scaling). """
21
  super().__init__()
22
  self.lora_name = lora_name
23
- self.lora_dim = lora_dim
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
28
- self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
29
- self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
30
  else:
31
  in_dim = org_module.in_features
32
  out_dim = org_module.out_features
33
- self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
34
- self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if type(alpha) == torch.Tensor:
37
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
38
- alpha = lora_dim if alpha is None or alpha == 0 else alpha
39
  self.scale = alpha / self.lora_dim
40
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
41
 
@@ -45,69 +58,192 @@ class LoRAModule(torch.nn.Module):
45
 
46
  self.multiplier = multiplier
47
  self.org_module = org_module # remove in applying
 
 
48
 
49
  def apply_to(self):
50
  self.org_forward = self.org_module.forward
51
  self.org_module.forward = self.forward
52
  del self.org_module
53
 
 
 
 
 
54
  def forward(self, x):
55
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
59
  if network_dim is None:
60
  network_dim = 4 # default
61
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
62
- return network
63
 
 
 
 
 
 
 
 
 
 
64
 
65
- def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
66
- if os.path.splitext(file)[1] == '.safetensors':
67
- from safetensors.torch import load_file, safe_open
68
- weights_sd = load_file(file)
69
- else:
70
- weights_sd = torch.load(file, map_location='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # get dim (rank)
73
- network_alpha = None
74
- network_dim = None
75
- for key, value in weights_sd.items():
76
- if network_alpha is None and 'alpha' in key:
77
- network_alpha = value
78
- if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
79
- network_dim = value.size()[0]
80
 
81
- if network_alpha is None:
82
- network_alpha = network_dim
83
 
84
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  network.weights_sd = weights_sd
86
  return network
87
 
88
 
89
  class LoRANetwork(torch.nn.Module):
 
90
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
 
91
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
  LORA_PREFIX_UNET = 'lora_unet'
93
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
 
95
- def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
96
  super().__init__()
97
  self.multiplier = multiplier
 
98
  self.lora_dim = lora_dim
99
  self.alpha = alpha
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # create module instances
102
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
103
  loras = []
104
  for name, module in root_module.named_modules():
105
  if module.__class__.__name__ in target_replace_modules:
 
106
  for child_name, child_module in module.named_modules():
107
- if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
 
 
 
108
  lora_name = prefix + '.' + name + '.' + child_name
109
  lora_name = lora_name.replace('.', '_')
110
- lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  loras.append(lora)
112
  return loras
113
 
@@ -115,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
115
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
 
118
- self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
 
 
 
 
 
119
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
 
121
  self.weights_sd = None
@@ -126,6 +267,11 @@ class LoRANetwork(torch.nn.Module):
126
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
  names.add(lora.lora_name)
128
 
 
 
 
 
 
129
  def load_weights(self, file):
130
  if os.path.splitext(file)[1] == '.safetensors':
131
  from safetensors.torch import load_file, safe_open
@@ -235,3 +381,18 @@ class LoRANetwork(torch.nn.Module):
235
  save_file(state_dict, file, metadata)
236
  else:
237
  torch.save(state_dict, file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import math
7
  import os
8
  from typing import List
9
+ import numpy as np
10
  import torch
11
 
12
  from library import train_util
 
21
  """ if alpha == 0 or None, alpha is rank (no scaling). """
22
  super().__init__()
23
  self.lora_name = lora_name
 
24
 
25
  if org_module.__class__.__name__ == 'Conv2d':
26
  in_dim = org_module.in_channels
27
  out_dim = org_module.out_channels
 
 
28
  else:
29
  in_dim = org_module.in_features
30
  out_dim = org_module.out_features
31
+
32
+ # if limit_rank:
33
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
34
+ # if self.lora_dim != lora_dim:
35
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
36
+ # else:
37
+ self.lora_dim = lora_dim
38
+
39
+ if org_module.__class__.__name__ == 'Conv2d':
40
+ kernel_size = org_module.kernel_size
41
+ stride = org_module.stride
42
+ padding = org_module.padding
43
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
44
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
45
+ else:
46
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
47
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
48
 
49
  if type(alpha) == torch.Tensor:
50
  alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
51
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
52
  self.scale = alpha / self.lora_dim
53
  self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
54
 
 
58
 
59
  self.multiplier = multiplier
60
  self.org_module = org_module # remove in applying
61
+ self.region = None
62
+ self.region_mask = None
63
 
64
  def apply_to(self):
65
  self.org_forward = self.org_module.forward
66
  self.org_module.forward = self.forward
67
  del self.org_module
68
 
69
+ def set_region(self, region):
70
+ self.region = region
71
+ self.region_mask = None
72
+
73
  def forward(self, x):
74
+ if self.region is None:
75
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
76
+
77
+ # regional LoRA FIXME same as additional-network extension
78
+ if x.size()[1] % 77 == 0:
79
+ # print(f"LoRA for context: {self.lora_name}")
80
+ self.region = None
81
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
82
+
83
+ # calculate region mask first time
84
+ if self.region_mask is None:
85
+ if len(x.size()) == 4:
86
+ h, w = x.size()[2:4]
87
+ else:
88
+ seq_len = x.size()[1]
89
+ ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
90
+ h = int(self.region.size()[0] / ratio + .5)
91
+ w = seq_len // h
92
+
93
+ r = self.region.to(x.device)
94
+ if r.dtype == torch.bfloat16:
95
+ r = r.to(torch.float)
96
+ r = r.unsqueeze(0).unsqueeze(1)
97
+ # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
98
+ r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
99
+ r = r.to(x.dtype)
100
+
101
+ if len(x.size()) == 3:
102
+ r = torch.reshape(r, (1, x.size()[1], -1))
103
+
104
+ self.region_mask = r
105
+
106
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
107
 
108
 
109
  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
110
  if network_dim is None:
111
  network_dim = 4 # default
 
 
112
 
113
+ # extract dim/alpha for conv2d, and block dim
114
+ conv_dim = kwargs.get('conv_dim', None)
115
+ conv_alpha = kwargs.get('conv_alpha', None)
116
+ if conv_dim is not None:
117
+ conv_dim = int(conv_dim)
118
+ if conv_alpha is None:
119
+ conv_alpha = 1.0
120
+ else:
121
+ conv_alpha = float(conv_alpha)
122
 
123
+ """
124
+ block_dims = kwargs.get("block_dims")
125
+ block_alphas = None
126
+
127
+ if block_dims is not None:
128
+ block_dims = [int(d) for d in block_dims.split(',')]
129
+ assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
130
+ block_alphas = kwargs.get("block_alphas")
131
+ if block_alphas is None:
132
+ block_alphas = [1] * len(block_dims)
133
+ else:
134
+ block_alphas = [int(a) for a in block_alphas.split(',')]
135
+ assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
136
+
137
+ conv_block_dims = kwargs.get("conv_block_dims")
138
+ conv_block_alphas = None
139
+
140
+ if conv_block_dims is not None:
141
+ conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
142
+ assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
143
+ conv_block_alphas = kwargs.get("conv_block_alphas")
144
+ if conv_block_alphas is None:
145
+ conv_block_alphas = [1] * len(conv_block_dims)
146
+ else:
147
+ conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
148
+ assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
149
+ """
150
 
151
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
152
+ alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
153
+ return network
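A hedged example of calling the updated factory directly; conv_dim/conv_alpha are passed as strings here because create_network casts them with int()/float(), which suggests they normally arrive as textual network arguments, and text_encoder/unet stand for whatever models the caller has already loaded:

# Sketch only: rank-8 LoRA on Linear/1x1 layers plus rank-4 LoRA on 3x3 convs.
network = create_network(1.0, 8, 4.0, None, text_encoder, unet,
                         conv_dim="4", conv_alpha="1.0")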
 
 
 
 
 
154
 
 
 
155
 
156
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
157
+ if weights_sd is None:
158
+ if os.path.splitext(file)[1] == '.safetensors':
159
+ from safetensors.torch import load_file, safe_open
160
+ weights_sd = load_file(file)
161
+ else:
162
+ weights_sd = torch.load(file, map_location='cpu')
163
+
164
+ # get dim/alpha mapping
165
+ modules_dim = {}
166
+ modules_alpha = {}
167
+ for key, value in weights_sd.items():
168
+ if '.' not in key:
169
+ continue
170
+
171
+ lora_name = key.split('.')[0]
172
+ if 'alpha' in key:
173
+ modules_alpha[lora_name] = value
174
+ elif 'lora_down' in key:
175
+ dim = value.size()[0]
176
+ modules_dim[lora_name] = dim
177
+ # print(lora_name, value.size(), dim)
178
+
179
+ # support old LoRA without alpha
180
+ for key in modules_dim.keys():
181
+ if key not in modules_alpha:
182
+ modules_alpha[key] = modules_dim[key]
183
+
184
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
185
  network.weights_sd = weights_sd
186
  return network
187
 
188
 
189
  class LoRANetwork(torch.nn.Module):
190
+ # is it possible to apply conv_in and conv_out?
191
  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
192
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
193
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
194
  LORA_PREFIX_UNET = 'lora_unet'
195
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
196
 
197
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
198
  super().__init__()
199
  self.multiplier = multiplier
200
+
201
  self.lora_dim = lora_dim
202
  self.alpha = alpha
203
+ self.conv_lora_dim = conv_lora_dim
204
+ self.conv_alpha = conv_alpha
205
+
206
+ if modules_dim is not None:
207
+ print(f"create LoRA network from weights")
208
+ else:
209
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
210
+
211
+ self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
212
+ if self.apply_to_conv2d_3x3:
213
+ if self.conv_alpha is None:
214
+ self.conv_alpha = self.alpha
215
+ print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
216
 
217
  # create module instances
218
  def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
219
  loras = []
220
  for name, module in root_module.named_modules():
221
  if module.__class__.__name__ in target_replace_modules:
222
+ # TODO get block index here
223
  for child_name, child_module in module.named_modules():
224
+ is_linear = child_module.__class__.__name__ == "Linear"
225
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
226
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
227
+ if is_linear or is_conv2d:
228
  lora_name = prefix + '.' + name + '.' + child_name
229
  lora_name = lora_name.replace('.', '_')
230
+
231
+ if modules_dim is not None:
232
+ if lora_name not in modules_dim:
233
+ continue # no LoRA module in this weights file
234
+ dim = modules_dim[lora_name]
235
+ alpha = modules_alpha[lora_name]
236
+ else:
237
+ if is_linear or is_conv2d_1x1:
238
+ dim = self.lora_dim
239
+ alpha = self.alpha
240
+ elif self.apply_to_conv2d_3x3:
241
+ dim = self.conv_lora_dim
242
+ alpha = self.conv_alpha
243
+ else:
244
+ continue
245
+
246
+ lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
247
  loras.append(lora)
248
  return loras
249
 
 
251
  text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
252
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
253
 
254
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
255
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
256
+ if modules_dim is not None or self.conv_lora_dim is not None:
257
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
258
+
259
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
260
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
261
 
262
  self.weights_sd = None
 
267
  assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
268
  names.add(lora.lora_name)
269
 
270
+ def set_multiplier(self, multiplier):
271
+ self.multiplier = multiplier
272
+ for lora in self.text_encoder_loras + self.unet_loras:
273
+ lora.multiplier = self.multiplier
274
+
275
  def load_weights(self, file):
276
  if os.path.splitext(file)[1] == '.safetensors':
277
  from safetensors.torch import load_file, safe_open
 
381
  save_file(state_dict, file, metadata)
382
  else:
383
  torch.save(state_dict, file)
384
+
385
+ @staticmethod
386
+ def set_regions(networks, image):
387
+ image = image.astype(np.float32) / 255.0
388
+ for i, network in enumerate(networks[:3]):
389
+ # NOTE: consider averaging overlapping areas
390
+ region = image[:, :, i]
391
+ if region.max() == 0:
392
+ continue
393
+ region = torch.tensor(region)
394
+ network.set_region(region)
395
+
396
+ def set_region(self, region):
397
+ for lora in self.unet_loras:
398
+ lora.set_region(region)
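A minimal sketch of driving the new regional hooks from inference code: set_regions maps the R, G and B channels of an HxWx3 uint8 mask to up to three networks, and each U-Net LoRA then scales its output by that mask resized to the current feature resolution. The import path, mask file name and the three network objects are illustrative assumptions:

import numpy as np
from PIL import Image
from networks.lora import LoRANetwork  # assumed import path

mask = np.array(Image.open("regions.png").convert("RGB"))  # hypothetical HxWx3 mask image
LoRANetwork.set_regions([network_a, network_b, network_c], mask)  # one network per channel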
networks/merge_lora.py CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
- if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
80
 
81
  # W <- W + U * D
82
  weight = module.weight
 
83
  if len(weight.size()) == 2:
84
  # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
- else:
87
- # conv2d
88
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
89
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
 
 
90
 
91
  module.weight = torch.nn.Parameter(weight)
92
 
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
123
  alphas[lora_module_name] = alpha
124
  if lora_module_name not in base_alphas:
125
  base_alphas[lora_module_name] = alpha
126
-
127
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
128
 
129
  # merge
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
145
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
146
  else:
147
  merged_sd[key] = lora_sd[key] * scale
148
-
149
  # set alpha to sd
150
  for lora_module_name, alpha in base_alphas.items():
151
  key = lora_module_name + ".alpha"
 
48
  for name, module in root_module.named_modules():
49
  if module.__class__.__name__ in target_replace_modules:
50
  for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
52
  lora_name = prefix + '.' + name + '.' + child_name
53
  lora_name = lora_name.replace('.', '_')
54
  name_to_module[lora_name] = child_module
 
80
 
81
  # W <- W + U * D
82
  weight = module.weight
83
+ # print(module_name, down_weight.size(), up_weight.size())
84
  if len(weight.size()) == 2:
85
  # linear
86
  weight = weight + ratio * (up_weight @ down_weight) * scale
87
+ elif down_weight.size()[2:4] == (1, 1):
88
+ # conv2d 1x1
89
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
90
  ).unsqueeze(2).unsqueeze(3) * scale
91
+ else:
92
+ # conv2d 3x3
93
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
94
+ # print(conved.size(), weight.size(), module.stride, module.padding)
95
+ weight = weight + ratio * conved * scale
96
 
97
  module.weight = torch.nn.Parameter(weight)
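The new 3x3 branch folds the LoRA pair into a single kernel: convolving lora_down (rank, in, 3, 3), permuted so its `in` axis acts as the batch, with lora_up (out, rank, 1, 1) produces the composed (out, in, 3, 3) weight that is added to the original conv. A small self-check sketch of that identity (sizes made up):

import torch
import torch.nn.functional as F

down = torch.randn(4, 8, 3, 3)   # lora_down: (rank, in, 3, 3)
up = torch.randn(16, 4, 1, 1)    # lora_up:   (out, rank, 1, 1)
merged = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)  # (out, in, 3, 3)

x = torch.randn(1, 8, 32, 32)
two_step = F.conv2d(F.conv2d(x, down, padding=1), up)   # apply the two LoRA convs in sequence
one_step = F.conv2d(x, merged, padding=1)               # apply the folded kernel once
assert torch.allclose(two_step, one_step, atol=1e-4)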
98
 
 
129
  alphas[lora_module_name] = alpha
130
  if lora_module_name not in base_alphas:
131
  base_alphas[lora_module_name] = alpha
132
+
133
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
134
 
135
  # merge
 
151
  merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
152
  else:
153
  merged_sd[key] = lora_sd[key] * scale
154
+
155
  # set alpha to sd
156
  for lora_module_name, alpha in base_alphas.items():
157
  key = lora_module_name + ".alpha"
networks/resize_lora.py CHANGED
@@ -1,14 +1,15 @@
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
- # Thanks to cloneofsimo and kohya
4
 
5
  import argparse
6
- import os
7
  import torch
8
  from safetensors.torch import load_file, save_file, safe_open
9
  from tqdm import tqdm
10
  from library import train_util, model_util
 
11
 
 
12
 
13
  def load_state_dict(file_name, dtype):
14
  if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
38
  torch.save(model, file_name)
39
 
40
 
41
- def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  network_alpha = None
43
  network_dim = None
44
  verbose_str = "\n"
45
-
46
- CLAMP_QUANTILE = 0.99
47
 
48
  # Extract loaded lora dim and alpha
49
  for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
57
  network_alpha = network_dim
58
 
59
  scale = network_alpha/network_dim
60
- new_alpha = float(scale*new_rank) # calculate new alpha from scale
61
 
62
- print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
 
63
 
64
  lora_down_weight = None
65
  lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
68
  block_down_name = None
69
  block_up_name = None
70
 
71
- print("resizing lora...")
72
  with torch.no_grad():
73
  for key, value in tqdm(lora_sd.items()):
74
  if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
85
  conv2d = (len(lora_down_weight.size()) == 4)
86
 
87
  if conv2d:
88
- lora_down_weight = lora_down_weight.squeeze()
89
- lora_up_weight = lora_up_weight.squeeze()
90
-
91
- if device:
92
- org_device = lora_up_weight.device
93
- lora_up_weight = lora_up_weight.to(args.device)
94
- lora_down_weight = lora_down_weight.to(args.device)
95
-
96
- full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
97
-
98
- U, S, Vh = torch.linalg.svd(full_weight_matrix)
99
 
100
  if verbose:
101
- s_sum = torch.sum(torch.abs(S))
102
- s_rank = torch.sum(torch.abs(S[:new_rank]))
103
- verbose_str+=f"{block_down_name:76} | "
104
- verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
 
105
 
106
- U = U[:, :new_rank]
107
- S = S[:new_rank]
108
- U = U @ torch.diag(S)
109
 
110
- Vh = Vh[:new_rank, :]
 
 
 
111
 
112
- dist = torch.cat([U.flatten(), Vh.flatten()])
113
- hi_val = torch.quantile(dist, CLAMP_QUANTILE)
114
- low_val = -hi_val
115
-
116
- U = U.clamp(low_val, hi_val)
117
- Vh = Vh.clamp(low_val, hi_val)
118
-
119
- if conv2d:
120
- U = U.unsqueeze(2).unsqueeze(3)
121
- Vh = Vh.unsqueeze(2).unsqueeze(3)
122
-
123
- if device:
124
- U = U.to(org_device)
125
- Vh = Vh.to(org_device)
126
-
127
- o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
128
- o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
129
- o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
130
 
131
  block_down_name = None
132
  block_up_name = None
133
  lora_down_weight = None
134
  lora_up_weight = None
135
  weights_loaded = False
 
136
 
137
  if verbose:
138
  print(verbose_str)
 
 
139
  print("resizing complete")
140
  return o_lora_sd, network_dim, new_alpha
141
 
@@ -151,6 +274,9 @@ def resize(args):
151
  return torch.bfloat16
152
  return None
153
 
 
 
 
154
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
155
  save_dtype = str_to_dtype(args.save_precision)
156
  if save_dtype is None:
@@ -159,17 +285,23 @@ def resize(args):
159
  print("loading Model...")
160
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
161
 
162
- print("resizing rank...")
163
- state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
164
 
165
  # update metadata
166
  if metadata is None:
167
  metadata = {}
168
 
169
  comment = metadata.get("ss_training_comment", "")
170
- metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
171
- metadata["ss_network_dim"] = str(args.new_rank)
172
- metadata["ss_network_alpha"] = str(new_alpha)
 
 
 
 
 
 
173
 
174
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
175
  metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ if __name__ == '__main__':
193
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
194
  parser.add_argument("--verbose", action="store_true",
195
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
 
 
196
 
197
  args = parser.parse_args()
198
  resize(args)
 
1
  # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
  # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo
4
 
5
  import argparse
 
6
  import torch
7
  from safetensors.torch import load_file, save_file, safe_open
8
  from tqdm import tqdm
9
  from library import train_util, model_util
10
+ import numpy as np
11
 
12
+ MIN_SV = 1e-6
13
 
14
  def load_state_dict(file_name, dtype):
15
  if model_util.is_safetensors(file_name):
 
39
  torch.save(model, file_name)
40
 
41
 
42
+ def index_sv_cumulative(S, target):
43
+ original_sum = float(torch.sum(S))
44
+ cumulative_sums = torch.cumsum(S, dim=0)/original_sum
45
+ index = int(torch.searchsorted(cumulative_sums, target)) + 1
46
+ if index >= len(S):
47
+ index = len(S) - 1
48
+
49
+ return index
50
+
51
+
52
+ def index_sv_fro(S, target):
53
+ S_squared = S.pow(2)
54
+ s_fro_sq = float(torch.sum(S_squared))
55
+ sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
56
+ index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
57
+ if index >= len(S):
58
+ index = len(S) - 1
59
+
60
+ return index
61
+
62
+
63
+ # Modified from Kohaku-blueleaf's extract/merge functions
64
+ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
65
+ out_size, in_size, kernel_size, _ = weight.size()
66
+ U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
67
+
68
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
69
+ lora_rank = param_dict["new_rank"]
70
+
71
+ U = U[:, :lora_rank]
72
+ S = S[:lora_rank]
73
+ U = U @ torch.diag(S)
74
+ Vh = Vh[:lora_rank, :]
75
+
76
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
77
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
78
+ del U, S, Vh, weight
79
+ return param_dict
80
+
81
+
82
+ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
83
+ out_size, in_size = weight.size()
84
+
85
+ U, S, Vh = torch.linalg.svd(weight.to(device))
86
+
87
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
88
+ lora_rank = param_dict["new_rank"]
89
+
90
+ U = U[:, :lora_rank]
91
+ S = S[:lora_rank]
92
+ U = U @ torch.diag(S)
93
+ Vh = Vh[:lora_rank, :]
94
+
95
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
96
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
97
+ del U, S, Vh, weight
98
+ return param_dict
99
+
100
+
101
+ def merge_conv(lora_down, lora_up, device):
102
+ in_rank, in_size, kernel_size, k_ = lora_down.shape
103
+ out_size, out_rank, _, _ = lora_up.shape
104
+ assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
105
+
106
+ lora_down = lora_down.to(device)
107
+ lora_up = lora_up.to(device)
108
+
109
+ merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
110
+ weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
111
+ del lora_up, lora_down
112
+ return weight
113
+
114
+
115
+ def merge_linear(lora_down, lora_up, device):
116
+ in_rank, in_size = lora_down.shape
117
+ out_size, out_rank = lora_up.shape
118
+ assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
119
+
120
+ lora_down = lora_down.to(device)
121
+ lora_up = lora_up.to(device)
122
+
123
+ weight = lora_up @ lora_down
124
+ del lora_up, lora_down
125
+ return weight
126
+
127
+
128
+ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
129
+ param_dict = {}
130
+
131
+ if dynamic_method=="sv_ratio":
132
+ # Calculate new dim and alpha based off ratio
133
+ max_sv = S[0]
134
+ min_sv = max_sv/dynamic_param
135
+ new_rank = max(torch.sum(S > min_sv).item(),1)
136
+ new_alpha = float(scale*new_rank)
137
+
138
+ elif dynamic_method=="sv_cumulative":
139
+ # Calculate new dim and alpha based off cumulative sum
140
+ new_rank = index_sv_cumulative(S, dynamic_param)
141
+ new_rank = max(new_rank, 1)
142
+ new_alpha = float(scale*new_rank)
143
+
144
+ elif dynamic_method=="sv_fro":
145
+ # Calculate new dim and alpha based off sqrt sum of squares
146
+ new_rank = index_sv_fro(S, dynamic_param)
147
+ new_rank = min(max(new_rank, 1), len(S)-1)
148
+ new_alpha = float(scale*new_rank)
149
+ else:
150
+ new_rank = rank
151
+ new_alpha = float(scale*new_rank)
152
+
153
+
154
+ if S[0] <= MIN_SV: # Zero matrix, set dim to 1
155
+ new_rank = 1
156
+ new_alpha = float(scale*new_rank)
157
+ elif new_rank > rank: # cap max rank at rank
158
+ new_rank = rank
159
+ new_alpha = float(scale*new_rank)
160
+
161
+
162
+ # Calculate resize info
163
+ s_sum = torch.sum(torch.abs(S))
164
+ s_rank = torch.sum(torch.abs(S[:new_rank]))
165
+
166
+ S_squared = S.pow(2)
167
+ s_fro = torch.sqrt(torch.sum(S_squared))
168
+ s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
169
+ fro_percent = float(s_red_fro/s_fro)
170
+
171
+ param_dict["new_rank"] = new_rank
172
+ param_dict["new_alpha"] = new_alpha
173
+ param_dict["sum_retained"] = (s_rank)/s_sum
174
+ param_dict["fro_retained"] = fro_percent
175
+ param_dict["max_ratio"] = S[0]/S[new_rank]
176
+
177
+ return param_dict
178
+
179
+
180
+ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
181
  network_alpha = None
182
  network_dim = None
183
  verbose_str = "\n"
184
+ fro_list = []
 
185
 
186
  # Extract loaded lora dim and alpha
187
  for key, value in lora_sd.items():
 
195
  network_alpha = network_dim
196
 
197
  scale = network_alpha/network_dim
 
198
 
199
+ if dynamic_method:
200
+ print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
201
 
202
  lora_down_weight = None
203
  lora_up_weight = None
 
206
  block_down_name = None
207
  block_up_name = None
208
 
 
209
  with torch.no_grad():
210
  for key, value in tqdm(lora_sd.items()):
211
  if 'lora_down' in key:
 
222
  conv2d = (len(lora_down_weight.size()) == 4)
223
 
224
  if conv2d:
225
+ full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
226
+ param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
227
+ else:
228
+ full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
229
+ param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
 
230
 
231
  if verbose:
232
+ max_ratio = param_dict['max_ratio']
233
+ sum_retained = param_dict['sum_retained']
234
+ fro_retained = param_dict['fro_retained']
235
+ if not np.isnan(fro_retained):
236
+ fro_list.append(float(fro_retained))
237
 
238
+ verbose_str+=f"{block_down_name:75} | "
239
+ verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
240
 
241
+ if verbose and dynamic_method:
242
+ verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
243
+ else:
244
+ verbose_str+=f"\n"
245
 
246
+ new_alpha = param_dict['new_alpha']
247
+ o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
248
+ o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
249
+ o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
 
 
 
250
 
251
  block_down_name = None
252
  block_up_name = None
253
  lora_down_weight = None
254
  lora_up_weight = None
255
  weights_loaded = False
256
+ del param_dict
257
 
258
  if verbose:
259
  print(verbose_str)
260
+
261
+ print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
262
  print("resizing complete")
263
  return o_lora_sd, network_dim, new_alpha
264
 
 
274
  return torch.bfloat16
275
  return None
276
 
277
+ if args.dynamic_method and not args.dynamic_param:
278
+ raise Exception("If using dynamic_method, then dynamic_param is required")
279
+
280
  merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
281
  save_dtype = str_to_dtype(args.save_precision)
282
  if save_dtype is None:
 
285
  print("loading Model...")
286
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
287
 
288
+ print("Resizing Lora...")
289
+ state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
290
 
291
  # update metadata
292
  if metadata is None:
293
  metadata = {}
294
 
295
  comment = metadata.get("ss_training_comment", "")
296
+
297
+ if not args.dynamic_method:
298
+ metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
299
+ metadata["ss_network_dim"] = str(args.new_rank)
300
+ metadata["ss_network_alpha"] = str(new_alpha)
301
+ else:
302
+ metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
303
+ metadata["ss_network_dim"] = 'Dynamic'
304
+ metadata["ss_network_alpha"] = 'Dynamic'
305
 
306
  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
307
  metadata["sshs_model_hash"] = model_hash
 
325
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
326
  parser.add_argument("--verbose", action="store_true",
327
  help="Display verbose resizing information / rank変更時の詳細情報を出力する")
328
+ parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
329
+ help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
330
+ parser.add_argument("--dynamic_param", type=float, default=None,
331
+ help="Specify target for dynamic reduction")
332
+
333
 
334
  args = parser.parse_args()
335
  resize(args)
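
Note on the resizing logic above: resize_lora.py now rebuilds each block's full delta weight with merge_linear / merge_conv, runs an SVD on it, and lets rank_resize pick a per-block rank, either from a singular-value ratio (sv_ratio), from the cumulative sum of singular values (sv_cumulative via index_sv_cumulative), or from the retained Frobenius norm (sv_fro via index_sv_fro), with --new_rank acting as a hard cap. The following is a minimal sketch of those three selection rules on a toy spectrum, with illustrative thresholds; it is not the script itself.

import torch

# toy singular-value spectrum; the script gets S from torch.linalg.svd of the merged LoRA weight
_, S, _ = torch.linalg.svd(torch.randn(64, 64))

def rank_by_sv_ratio(S, ratio):
    # sv_ratio: keep singular values no smaller than max(S) / ratio
    return max(int(torch.sum(S > S[0] / ratio).item()), 1)

def rank_by_cumulative(S, target):
    # sv_cumulative: smallest k whose normalized cumulative sum of S reaches target
    cum = torch.cumsum(S, dim=0) / torch.sum(S)
    return min(int(torch.searchsorted(cum, target)) + 1, len(S) - 1)

def rank_by_fro(S, target):
    # sv_fro: smallest k that retains a `target` fraction of the Frobenius norm (sum of S**2)
    cum = torch.cumsum(S.pow(2), dim=0) / torch.sum(S.pow(2))
    return min(int(torch.searchsorted(cum, target ** 2)) + 1, len(S) - 1)

print(rank_by_sv_ratio(S, 2.0), rank_by_cumulative(S, 0.95), rank_by_fro(S, 0.95))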
networks/svd_merge_lora.py CHANGED
@@ -23,19 +23,20 @@ def load_state_dict(file_name, dtype):
23
  return sd
24
 
25
 
26
- def save_to_file(file_name, model, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
- save_file(model, file_name)
34
  else:
35
- torch.save(model, file_name)
36
 
37
 
38
- def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
 
39
  merged_sd = {}
40
  for model, ratio in zip(models, ratios):
41
  print(f"loading: {model}")
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
58
  in_dim = down_weight.size()[1]
59
  out_dim = up_weight.size()[0]
60
  conv2d = len(down_weight.size()) == 4
61
- print(lora_module_name, network_dim, alpha, in_dim, out_dim)
 
62
 
63
  # make original weight if not exist
64
  if lora_module_name not in merged_sd:
65
- weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
66
  if device:
67
  weight = weight.to(device)
68
  else:
@@ -75,11 +77,18 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
75
 
76
  # W <- W + U * D
77
  scale = (alpha / network_dim)
 
 
 
 
78
  if not conv2d: # linear
79
  weight = weight + ratio * (up_weight @ down_weight) * scale
80
- else:
81
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
82
  ).unsqueeze(2).unsqueeze(3) * scale
 
 
 
83
 
84
  merged_sd[lora_module_name] = weight
85
 
@@ -89,16 +98,26 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
89
  with torch.no_grad():
90
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
91
  conv2d = (len(mat.size()) == 4)
 
 
 
 
92
  if conv2d:
93
- mat = mat.squeeze()
 
 
94
 
95
  U, S, Vh = torch.linalg.svd(mat)
96
 
97
- U = U[:, :new_rank]
98
- S = S[:new_rank]
99
  U = U @ torch.diag(S)
100
 
101
- Vh = Vh[:new_rank, :]
102
 
103
  dist = torch.cat([U.flatten(), Vh.flatten()])
104
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
@@ -107,16 +126,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
107
  U = U.clamp(low_val, hi_val)
108
  Vh = Vh.clamp(low_val, hi_val)
109
 
 
 
 
 
110
  up_weight = U
111
  down_weight = Vh
112
 
113
- if conv2d:
114
- up_weight = up_weight.unsqueeze(2).unsqueeze(3)
115
- down_weight = down_weight.unsqueeze(2).unsqueeze(3)
116
-
117
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
118
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
119
- merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
120
 
121
  return merged_lora_sd
122
 
@@ -138,10 +157,11 @@ def merge(args):
138
  if save_dtype is None:
139
  save_dtype = merge_dtype
140
 
141
- state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
 
142
 
143
  print(f"saving model to: {args.save_to}")
144
- save_to_file(args.save_to, state_dict, state_dict, save_dtype)
145
 
146
 
147
  if __name__ == '__main__':
@@ -158,6 +178,8 @@ if __name__ == '__main__':
158
  help="ratios for each model / それぞれのLoRAモデルの比率")
159
  parser.add_argument("--new_rank", type=int, default=4,
160
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
 
 
161
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
 
163
  args = parser.parse_args()
 
23
  return sd
24
 
25
 
26
+ def save_to_file(file_name, state_dict, dtype):
27
  if dtype is not None:
28
  for key in list(state_dict.keys()):
29
  if type(state_dict[key]) == torch.Tensor:
30
  state_dict[key] = state_dict[key].to(dtype)
31
 
32
  if os.path.splitext(file_name)[1] == '.safetensors':
33
+ save_file(state_dict, file_name)
34
  else:
35
+ torch.save(state_dict, file_name)
36
 
37
 
38
+ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
39
+ print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
40
  merged_sd = {}
41
  for model, ratio in zip(models, ratios):
42
  print(f"loading: {model}")
 
59
  in_dim = down_weight.size()[1]
60
  out_dim = up_weight.size()[0]
61
  conv2d = len(down_weight.size()) == 4
62
+ kernel_size = None if not conv2d else down_weight.size()[2:4]
63
+ # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
64
 
65
  # make original weight if not exist
66
  if lora_module_name not in merged_sd:
67
+ weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
68
  if device:
69
  weight = weight.to(device)
70
  else:
 
77
 
78
  # W <- W + U * D
79
  scale = (alpha / network_dim)
80
+
81
+ if device: # and isinstance(scale, torch.Tensor):
82
+ scale = scale.to(device)
83
+
84
  if not conv2d: # linear
85
  weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ elif kernel_size == (1, 1):
87
  weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
88
  ).unsqueeze(2).unsqueeze(3) * scale
89
+ else:
90
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
91
+ weight = weight + ratio * conved * scale
92
 
93
  merged_sd[lora_module_name] = weight
94
 
 
98
  with torch.no_grad():
99
  for lora_module_name, mat in tqdm(list(merged_sd.items())):
100
  conv2d = (len(mat.size()) == 4)
101
+ kernel_size = None if not conv2d else mat.size()[2:4]
102
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
103
+ out_dim, in_dim = mat.size()[0:2]
104
+
105
  if conv2d:
106
+ if conv2d_3x3:
107
+ mat = mat.flatten(start_dim=1)
108
+ else:
109
+ mat = mat.squeeze()
110
+
111
+ module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
112
+ module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
113
 
114
  U, S, Vh = torch.linalg.svd(mat)
115
 
116
+ U = U[:, :module_new_rank]
117
+ S = S[:module_new_rank]
118
  U = U @ torch.diag(S)
119
 
120
+ Vh = Vh[:module_new_rank, :]
121
 
122
  dist = torch.cat([U.flatten(), Vh.flatten()])
123
  hi_val = torch.quantile(dist, CLAMP_QUANTILE)
 
126
  U = U.clamp(low_val, hi_val)
127
  Vh = Vh.clamp(low_val, hi_val)
128
 
129
+ if conv2d:
130
+ U = U.reshape(out_dim, module_new_rank, 1, 1)
131
+ Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
132
+
133
  up_weight = U
134
  down_weight = Vh
135
 
 
 
 
 
136
  merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
137
  merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
138
+ merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
139
 
140
  return merged_lora_sd
141
 
 
157
  if save_dtype is None:
158
  save_dtype = merge_dtype
159
 
160
+ new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
161
+ state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
162
 
163
  print(f"saving model to: {args.save_to}")
164
+ save_to_file(args.save_to, state_dict, save_dtype)
165
 
166
 
167
  if __name__ == '__main__':
 
178
  help="ratios for each model / それぞれのLoRAモデルの比率")
179
  parser.add_argument("--new_rank", type=int, default=4,
180
  help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
181
+ parser.add_argument("--new_conv_rank", type=int, default=None,
182
+ help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
183
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
184
 
185
  args = parser.parse_args()
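
For the Conv2d 3x3 support added above: 1x1 conv LoRA pairs are still merged as plain matrix products, while 3x3 modules are composed with torch.nn.functional.conv2d over the rank dimension and then re-factorized at --new_conv_rank after flattening the kernel axes (the script additionally clamps U and Vh to CLAMP_QUANTILE before saving). A rough sketch of that path with toy shapes, for illustration only:

import torch
import torch.nn.functional as F

rank, in_ch, out_ch, k = 8, 32, 64, 3
down = torch.randn(rank, in_ch, k, k)   # lora_down of a 3x3 conv module
up = torch.randn(out_ch, rank, 1, 1)    # lora_up is 1x1

# W <- up * down, written as a convolution over the rank dimension (as in the merge step above)
delta = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)  # (out_ch, in_ch, k, k)

# re-factorize the flattened delta at the conv rank
new_conv_rank = 4
U, S, Vh = torch.linalg.svd(delta.flatten(start_dim=1), full_matrices=False)
new_up = (U[:, :new_conv_rank] @ torch.diag(S[:new_conv_rank])).reshape(out_ch, new_conv_rank, 1, 1)
new_down = Vh[:new_conv_rank, :].reshape(new_conv_rank, in_ch, k, k)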
requirements.txt CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
 
 
15
  # for BLIP captioning
16
  requests==2.28.2
17
  timm==0.6.12
@@ -21,5 +23,4 @@ fairscale==0.4.13
21
  tensorflow==2.10.1
22
  huggingface-hub==0.12.0
23
  # for kohya_ss library
24
- #locon.locon_kohya
25
  .
 
12
  gradio==3.16.2
13
  altair==4.2.2
14
  easygui==0.98.3
15
+ toml==0.10.2
16
+ voluptuous==0.13.1
17
  # for BLIP captioning
18
  requests==2.28.2
19
  timm==0.6.12
 
23
  tensorflow==2.10.1
24
  huggingface-hub==0.12.0
25
  # for kohya_ss library
 
26
  .
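
The two new requirements, toml and voluptuous, back the dataset-config support wired into the training scripts below: library/config_util.py loads a TOML file and validates it before building the datasets. A generic illustration of how that pair is typically used; the schema here is made up, the real one lives in library/config_util.py:

import toml
from voluptuous import Optional, Schema

# hypothetical schema for illustration only
schema = Schema({
    "datasets": [{
        Optional("resolution"): int,
        "subsets": [{"image_dir": str, Optional("num_repeats"): int}],
    }]
})

config = toml.loads("""
[[datasets]]
resolution = 512
  [[datasets.subsets]]
  image_dir = "train_images"
  num_repeats = 10
""")

print(schema(config))  # raises voluptuous.MultipleInvalid if the config is malformed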
train_db.py CHANGED
@@ -15,7 +15,11 @@ import diffusers
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
- from library.train_util import DreamBoothDataset
 
 
 
 
19
 
20
 
21
  def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):
33
 
34
  tokenizer = train_util.load_tokenizer(args)
35
 
36
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
37
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
38
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
39
- args.bucket_reso_steps, args.bucket_no_upscale,
40
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
41
-
42
- if args.no_token_padding:
43
- train_dataset.disable_token_padding()
 
 
44
 
45
- # 学習データのdropout率を設定する
46
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
47
 
48
- train_dataset.make_buckets()
 
49
 
50
  if args.debug_dataset:
51
- train_util.debug_dataset(train_dataset)
52
  return
53
 
 
 
 
54
  # acceleratorを準備する
55
  print("prepare accelerator")
56
 
@@ -91,7 +104,7 @@ def train(args):
91
  vae.requires_grad_(False)
92
  vae.eval()
93
  with torch.no_grad():
94
- train_dataset.cache_latents(vae)
95
  vae.to("cpu")
96
  if torch.cuda.is_available():
97
  torch.cuda.empty_cache()
@@ -115,38 +128,18 @@ def train(args):
115
 
116
  # 学習に必要なクラスを準備する
117
  print("prepare optimizer, data loader etc.")
118
-
119
- # 8-bit Adamを使う
120
- if args.use_8bit_adam:
121
- try:
122
- import bitsandbytes as bnb
123
- except ImportError:
124
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
125
- print("use 8-bit Adam optimizer")
126
- optimizer_class = bnb.optim.AdamW8bit
127
- elif args.use_lion_optimizer:
128
- try:
129
- import lion_pytorch
130
- except ImportError:
131
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
132
- print("use Lion optimizer")
133
- optimizer_class = lion_pytorch.Lion
134
- else:
135
- optimizer_class = torch.optim.AdamW
136
-
137
  if train_text_encoder:
138
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
139
  else:
140
  trainable_params = unet.parameters()
141
 
142
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
143
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
144
 
145
  # dataloaderを準備する
146
  # DataLoaderのプロセス数:0はメインプロセスになる
147
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
148
  train_dataloader = torch.utils.data.DataLoader(
149
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
150
 
151
  # 学習ステップ数を計算する
152
  if args.max_train_epochs is not None:
@@ -156,9 +149,10 @@ def train(args):
156
  if args.stop_text_encoder_training is None:
157
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
158
 
159
- # lr schedulerを用意する
160
- lr_scheduler = diffusers.optimization.get_scheduler(
161
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
 
162
 
163
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
164
  if args.full_fp16:
@@ -195,8 +189,8 @@ def train(args):
195
  # 学習する
196
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
197
  print("running training / 学習開始")
198
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
199
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
200
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
201
  print(f" num epochs / epoch数: {num_train_epochs}")
202
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -217,7 +211,7 @@ def train(args):
217
  loss_total = 0.0
218
  for epoch in range(num_train_epochs):
219
  print(f"epoch {epoch+1}/{num_train_epochs}")
220
- train_dataset.set_current_epoch(epoch + 1)
221
 
222
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
223
  unet.train()
@@ -281,12 +275,12 @@ def train(args):
281
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
282
 
283
  accelerator.backward(loss)
284
- if accelerator.sync_gradients:
285
  if train_text_encoder:
286
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
287
  else:
288
  params_to_clip = unet.parameters()
289
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
290
 
291
  optimizer.step()
292
  lr_scheduler.step()
@@ -297,9 +291,13 @@ def train(args):
297
  progress_bar.update(1)
298
  global_step += 1
299
 
 
 
300
  current_loss = loss.detach().item()
301
  if args.logging_dir is not None:
302
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
303
  accelerator.log(logs, step=global_step)
304
 
305
  if epoch == 0:
@@ -326,6 +324,8 @@ def train(args):
326
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
327
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
328
 
 
 
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
@@ -352,6 +352,8 @@ if __name__ == '__main__':
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
 
 
355
 
356
  parser.add_argument("--no_token_padding", action="store_true",
357
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
 
15
  from diffusers import DDPMScheduler
16
 
17
  import library.train_util as train_util
18
+ import library.config_util as config_util
19
+ from library.config_util import (
20
+ ConfigSanitizer,
21
+ BlueprintGenerator,
22
+ )
23
 
24
 
25
  def collate_fn(examples):
 
37
 
38
  tokenizer = train_util.load_tokenizer(args)
39
 
40
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
41
+ if args.dataset_config is not None:
42
+ print(f"Load dataset config from {args.dataset_config}")
43
+ user_config = config_util.load_user_config(args.dataset_config)
44
+ ignored = ["train_data_dir", "reg_data_dir"]
45
+ if any(getattr(args, attr) is not None for attr in ignored):
46
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
47
+ else:
48
+ user_config = {
49
+ "datasets": [{
50
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
51
+ }]
52
+ }
53
 
54
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
55
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
56
 
57
+ if args.no_token_padding:
58
+ train_dataset_group.disable_token_padding()
59
 
60
  if args.debug_dataset:
61
+ train_util.debug_dataset(train_dataset_group)
62
  return
63
 
64
+ if cache_latents:
65
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
66
+
67
  # acceleratorを準備する
68
  print("prepare accelerator")
69
 
 
104
  vae.requires_grad_(False)
105
  vae.eval()
106
  with torch.no_grad():
107
+ train_dataset_group.cache_latents(vae)
108
  vae.to("cpu")
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
 
128
 
129
  # 学習に必要なクラスを準備する
130
  print("prepare optimizer, data loader etc.")
 
 
131
  if train_text_encoder:
132
  trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
133
  else:
134
  trainable_params = unet.parameters()
135
 
136
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
137
 
138
  # dataloaderを準備する
139
  # DataLoaderのプロセス数:0はメインプロセスになる
140
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
141
  train_dataloader = torch.utils.data.DataLoader(
142
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
143
 
144
  # 学習ステップ数を計算する
145
  if args.max_train_epochs is not None:
 
149
  if args.stop_text_encoder_training is None:
150
  args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
151
 
152
+ # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
153
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
154
+ num_training_steps=args.max_train_steps,
155
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
156
 
157
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
158
  if args.full_fp16:
 
189
  # 学習する
190
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
191
  print("running training / 学習開始")
192
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
193
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
194
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
195
  print(f" num epochs / epoch数: {num_train_epochs}")
196
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
211
  loss_total = 0.0
212
  for epoch in range(num_train_epochs):
213
  print(f"epoch {epoch+1}/{num_train_epochs}")
214
+ train_dataset_group.set_current_epoch(epoch + 1)
215
 
216
  # 指定したステップ数までText Encoderを学習する:epoch最初の状態
217
  unet.train()
 
275
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
276
 
277
  accelerator.backward(loss)
278
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
279
  if train_text_encoder:
280
  params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
281
  else:
282
  params_to_clip = unet.parameters()
283
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
284
 
285
  optimizer.step()
286
  lr_scheduler.step()
 
291
  progress_bar.update(1)
292
  global_step += 1
293
 
294
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
295
+
296
  current_loss = loss.detach().item()
297
  if args.logging_dir is not None:
298
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
299
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
300
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
301
  accelerator.log(logs, step=global_step)
302
 
303
  if epoch == 0:
 
324
  train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
325
  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
326
 
327
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
328
+
329
  is_main_process = accelerator.is_main_process
330
  if is_main_process:
331
  unet = unwrap_model(unet)
 
352
  train_util.add_dataset_arguments(parser, True, False, True)
353
  train_util.add_training_arguments(parser, True)
354
  train_util.add_sd_saving_arguments(parser)
355
+ train_util.add_optimizer_arguments(parser)
356
+ config_util.add_config_arguments(parser)
357
 
358
  parser.add_argument("--no_token_padding", action="store_true",
359
  help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
train_network.py CHANGED
@@ -1,8 +1,4 @@
1
- from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
2
- from torch.optim import Optimizer
3
- from torch.cuda.amp import autocast
4
  from torch.nn.parallel import DistributedDataParallel as DDP
5
- from typing import Optional, Union
6
  import importlib
7
  import argparse
8
  import gc
@@ -15,92 +11,39 @@ import json
15
  from tqdm import tqdm
16
  import torch
17
  from accelerate.utils import set_seed
18
- import diffusers
19
  from diffusers import DDPMScheduler
20
 
21
  import library.train_util as train_util
22
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
23
 
24
 
25
  def collate_fn(examples):
26
  return examples[0]
27
 
28
 
 
29
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
30
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
31
 
32
  if args.network_train_unet_only:
33
- logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
34
  elif args.network_train_text_encoder_only:
35
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
36
  else:
37
- logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
38
- logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
39
-
40
- return logs
41
 
 
 
42
 
43
- # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
44
- # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
45
- # Which is a newer release of diffusers than currently packaged with sd-scripts
46
- # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
47
-
48
-
49
- def get_scheduler_fix(
50
- name: Union[str, SchedulerType],
51
- optimizer: Optimizer,
52
- num_warmup_steps: Optional[int] = None,
53
- num_training_steps: Optional[int] = None,
54
- num_cycles: int = 1,
55
- power: float = 1.0,
56
- ):
57
- """
58
- Unified API to get any scheduler from its name.
59
- Args:
60
- name (`str` or `SchedulerType`):
61
- The name of the scheduler to use.
62
- optimizer (`torch.optim.Optimizer`):
63
- The optimizer that will be used during training.
64
- num_warmup_steps (`int`, *optional*):
65
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
66
- optional), the function will raise an error if it's unset and the scheduler type requires it.
67
- num_training_steps (`int``, *optional*):
68
- The number of training steps to do. This is not required by all schedulers (hence the argument being
69
- optional), the function will raise an error if it's unset and the scheduler type requires it.
70
- num_cycles (`int`, *optional*):
71
- The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
72
- power (`float`, *optional*, defaults to 1.0):
73
- Power factor. See `POLYNOMIAL` scheduler
74
- last_epoch (`int`, *optional*, defaults to -1):
75
- The index of the last epoch when resuming training.
76
- """
77
- name = SchedulerType(name)
78
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
79
- if name == SchedulerType.CONSTANT:
80
- return schedule_func(optimizer)
81
-
82
- # All other schedulers require `num_warmup_steps`
83
- if num_warmup_steps is None:
84
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
85
-
86
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
87
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
88
-
89
- # All other schedulers require `num_training_steps`
90
- if num_training_steps is None:
91
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
92
-
93
- if name == SchedulerType.COSINE_WITH_RESTARTS:
94
- return schedule_func(
95
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
96
- )
97
-
98
- if name == SchedulerType.POLYNOMIAL:
99
- return schedule_func(
100
- optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
101
- )
102
-
103
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
104
 
105
 
106
  def train(args):
@@ -111,6 +54,7 @@ def train(args):
111
 
112
  cache_latents = args.cache_latents
113
  use_dreambooth_method = args.in_json is None
 
114
 
115
  if args.seed is not None:
116
  set_seed(args.seed)
@@ -118,38 +62,51 @@ def train(args):
118
  tokenizer = train_util.load_tokenizer(args)
119
 
120
  # データセットを準備する
121
- if use_dreambooth_method:
122
- print("Use DreamBooth method.")
123
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
124
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
125
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
126
- args.bucket_reso_steps, args.bucket_no_upscale,
127
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
128
- args.random_crop, args.debug_dataset)
129
  else:
130
- print("Train with captions.")
131
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
132
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
133
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
134
- args.bucket_reso_steps, args.bucket_no_upscale,
135
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
136
- args.dataset_repeats, args.debug_dataset)
137
-
138
- # 学習データのdropout率を設定する
139
- train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
140
-
141
- train_dataset.make_buckets()
 
 
 
142
 
143
  if args.debug_dataset:
144
- train_util.debug_dataset(train_dataset)
145
  return
146
- if len(train_dataset) == 0:
147
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
148
  return
149
 
 
 
 
 
150
  # acceleratorを準備する
151
  print("prepare accelerator")
152
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
 
153
 
154
  # mixed precisionに対応した型を用意しておき適宜castする
155
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -161,7 +118,7 @@ def train(args):
161
  if args.lowram:
162
  text_encoder.to("cuda")
163
  unet.to("cuda")
164
-
165
  # モデルに xformers とか memory efficient attention を組み込む
166
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
167
 
@@ -171,13 +128,15 @@ def train(args):
171
  vae.requires_grad_(False)
172
  vae.eval()
173
  with torch.no_grad():
174
- train_dataset.cache_latents(vae)
175
  vae.to("cpu")
176
  if torch.cuda.is_available():
177
  torch.cuda.empty_cache()
178
  gc.collect()
179
 
180
  # prepare network
 
 
181
  print("import network module:", args.network_module)
182
  network_module = importlib.import_module(args.network_module)
183
 
@@ -208,48 +167,25 @@ def train(args):
208
  # 学習に必要なクラスを準備する
209
  print("prepare optimizer, data loader etc.")
210
 
211
- # 8-bit Adamを使う
212
- if args.use_8bit_adam:
213
- try:
214
- import bitsandbytes as bnb
215
- except ImportError:
216
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
217
- print("use 8-bit Adam optimizer")
218
- optimizer_class = bnb.optim.AdamW8bit
219
- elif args.use_lion_optimizer:
220
- try:
221
- import lion_pytorch
222
- except ImportError:
223
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
224
- print("use Lion optimizer")
225
- optimizer_class = lion_pytorch.Lion
226
- else:
227
- optimizer_class = torch.optim.AdamW
228
-
229
- optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
230
-
231
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
232
-
233
- # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
234
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
235
 
236
  # dataloaderを準備する
237
  # DataLoaderのプロセス数:0はメインプロセスになる
238
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
239
  train_dataloader = torch.utils.data.DataLoader(
240
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
241
 
242
  # 学習ステップ数を計算する
243
  if args.max_train_epochs is not None:
244
- args.max_train_steps = args.max_train_epochs * len(train_dataloader)
245
- print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
246
 
247
  # lr schedulerを用意する
248
- # lr_scheduler = diffusers.optimization.get_scheduler(
249
- lr_scheduler = get_scheduler_fix(
250
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
251
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
252
- num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
253
 
254
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
255
  if args.full_fp16:
@@ -317,17 +253,21 @@ def train(args):
317
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
318
 
319
  # 学習する
 
320
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
321
- print("running training / 学習開始")
322
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
323
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
324
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
325
- print(f" num epochs / epoch数: {num_train_epochs}")
326
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
327
- print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
328
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
329
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
330
-
 
 
 
331
  metadata = {
332
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
333
  "ss_training_started_at": training_started_at, # unix timestamp
@@ -335,12 +275,10 @@ def train(args):
335
  "ss_learning_rate": args.learning_rate,
336
  "ss_text_encoder_lr": args.text_encoder_lr,
337
  "ss_unet_lr": args.unet_lr,
338
- "ss_num_train_images": train_dataset.num_train_images, # includes repeating
339
- "ss_num_reg_images": train_dataset.num_reg_images,
340
  "ss_num_batches_per_epoch": len(train_dataloader),
341
  "ss_num_epochs": num_train_epochs,
342
- "ss_batch_size_per_device": args.train_batch_size,
343
- "ss_total_batch_size": total_batch_size,
344
  "ss_gradient_checkpointing": args.gradient_checkpointing,
345
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
346
  "ss_max_train_steps": args.max_train_steps,
@@ -352,33 +290,156 @@ def train(args):
352
  "ss_mixed_precision": args.mixed_precision,
353
  "ss_full_fp16": bool(args.full_fp16),
354
  "ss_v2": bool(args.v2),
355
- "ss_resolution": args.resolution,
356
  "ss_clip_skip": args.clip_skip,
357
  "ss_max_token_length": args.max_token_length,
358
- "ss_color_aug": bool(args.color_aug),
359
- "ss_flip_aug": bool(args.flip_aug),
360
- "ss_random_crop": bool(args.random_crop),
361
- "ss_shuffle_caption": bool(args.shuffle_caption),
362
  "ss_cache_latents": bool(args.cache_latents),
363
- "ss_enable_bucket": bool(train_dataset.enable_bucket),
364
- "ss_min_bucket_reso": train_dataset.min_bucket_reso,
365
- "ss_max_bucket_reso": train_dataset.max_bucket_reso,
366
  "ss_seed": args.seed,
367
- "ss_keep_tokens": args.keep_tokens,
368
  "ss_noise_offset": args.noise_offset,
369
- "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
370
- "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
371
- "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
372
- "ss_bucket_info": json.dumps(train_dataset.bucket_info),
373
  "ss_training_comment": args.training_comment, # will not be updated after training
374
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
375
- "ss_optimizer": optimizer_name
 
 
376
  }
377
 
378
- # uncomment if another network is added
379
- # for key, value in net_kwargs.items():
380
- # metadata["ss_arg_" + key] = value
381
-
 
 
 
 
382
  if args.pretrained_model_name_or_path is not None:
383
  sd_model_name = args.pretrained_model_name_or_path
384
  if os.path.exists(sd_model_name):
@@ -397,6 +458,13 @@ def train(args):
397
 
398
  metadata = {k: str(v) for k, v in metadata.items()}
399
 
 
 
400
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
401
  global_step = 0
402
 
@@ -409,8 +477,9 @@ def train(args):
409
  loss_list = []
410
  loss_total = 0.0
411
  for epoch in range(num_train_epochs):
412
- print(f"epoch {epoch+1}/{num_train_epochs}")
413
- train_dataset.set_current_epoch(epoch + 1)
 
414
 
415
  metadata["ss_epoch"] = str(epoch+1)
416
 
@@ -447,7 +516,7 @@ def train(args):
447
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
448
 
449
  # Predict the noise residual
450
- with autocast():
451
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
452
 
453
  if args.v_parameterization:
@@ -465,9 +534,9 @@ def train(args):
465
  loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
466
 
467
  accelerator.backward(loss)
468
- if accelerator.sync_gradients:
469
  params_to_clip = network.get_trainable_params()
470
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
471
 
472
  optimizer.step()
473
  lr_scheduler.step()
@@ -478,6 +547,8 @@ def train(args):
478
  progress_bar.update(1)
479
  global_step += 1
480
 
 
 
481
  current_loss = loss.detach().item()
482
  if epoch == 0:
483
  loss_list.append(current_loss)
@@ -508,8 +579,9 @@ def train(args):
508
  def save_func():
509
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
510
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
511
  print(f"saving checkpoint: {ckpt_file}")
512
- unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
513
 
514
  def remove_old_func(old_epoch_no):
515
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -518,15 +590,18 @@ def train(args):
518
  print(f"removing old checkpoint: {old_ckpt_file}")
519
  os.remove(old_ckpt_file)
520
 
521
- saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
522
- if saving and args.save_state:
523
- train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
 
 
524
 
525
  # end of epoch
526
 
527
  metadata["ss_epoch"] = str(num_train_epochs)
 
528
 
529
- is_main_process = accelerator.is_main_process
530
  if is_main_process:
531
  network = unwrap_model(network)
532
 
@@ -545,7 +620,7 @@ def train(args):
545
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
546
 
547
  print(f"save trained model to {ckpt_file}")
548
- network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
549
  print("model saved.")
550
 
551
 
@@ -555,6 +630,8 @@ if __name__ == '__main__':
555
  train_util.add_sd_models_arguments(parser)
556
  train_util.add_dataset_arguments(parser, True, True, True)
557
  train_util.add_training_arguments(parser, True)
 
 
558
 
559
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
560
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -562,10 +639,6 @@ if __name__ == '__main__':
562
 
563
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
564
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
565
- parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
566
- help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
567
- parser.add_argument("--lr_scheduler_power", type=float, default=1,
568
- help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
569
 
570
  parser.add_argument("--network_weights", type=str, default=None,
571
  help="pretrained weights for network / 学習するネットワークの初期重み")
 
 
 
 
1
  from torch.nn.parallel import DistributedDataParallel as DDP
 
2
  import importlib
3
  import argparse
4
  import gc
 
11
  from tqdm import tqdm
12
  import torch
13
  from accelerate.utils import set_seed
 
14
  from diffusers import DDPMScheduler
15
 
16
  import library.train_util as train_util
17
+ from library.train_util import (
18
+ DreamBoothDataset,
19
+ )
20
+ import library.config_util as config_util
21
+ from library.config_util import (
22
+ ConfigSanitizer,
23
+ BlueprintGenerator,
24
+ )
25
 
26
 
27
  def collate_fn(examples):
28
  return examples[0]
29
 
30
 
31
+ # TODO 他のスクリプトと共通化する
32
  def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
33
  logs = {"loss/current": current_loss, "loss/average": avr_loss}
34
 
35
  if args.network_train_unet_only:
36
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
37
  elif args.network_train_text_encoder_only:
38
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
39
  else:
40
+ logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
41
+ logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
 
 
42
 
43
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
44
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
45
 
46
+ return logs
 
 
 
 
47
 
48
 
49
  def train(args):
 
54
 
55
  cache_latents = args.cache_latents
56
  use_dreambooth_method = args.in_json is None
57
+ use_user_config = args.dataset_config is not None
58
 
59
  if args.seed is not None:
60
  set_seed(args.seed)
 
62
  tokenizer = train_util.load_tokenizer(args)
63
 
64
  # データセットを準備する
65
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
66
+ if use_user_config:
67
+ print(f"Load dataset config from {args.dataset_config}")
68
+ user_config = config_util.load_user_config(args.dataset_config)
69
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
70
+ if any(getattr(args, attr) is not None for attr in ignored):
71
+ print(
72
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
73
  else:
74
+ if use_dreambooth_method:
75
+ print("Use DreamBooth method.")
76
+ user_config = {
77
+ "datasets": [{
78
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
79
+ }]
80
+ }
81
+ else:
82
+ print("Train with captions.")
83
+ user_config = {
84
+ "datasets": [{
85
+ "subsets": [{
86
+ "image_dir": args.train_data_dir,
87
+ "metadata_file": args.in_json,
88
+ }]
89
+ }]
90
+ }
91
+
92
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
93
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
94
 
95
  if args.debug_dataset:
96
+ train_util.debug_dataset(train_dataset_group)
97
  return
98
+ if len(train_dataset_group) == 0:
99
  print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
100
  return
101
 
102
+ if cache_latents:
103
+ assert train_dataset_group.is_latent_cacheable(
104
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
105
+
106
  # acceleratorを準備する
107
  print("prepare accelerator")
108
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
109
+ is_main_process = accelerator.is_main_process
110
 
111
  # mixed precisionに対応した型を用意しておき適宜castする
112
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
118
  if args.lowram:
119
  text_encoder.to("cuda")
120
  unet.to("cuda")
121
+
122
  # モデルに xformers とか memory efficient attention を組み込む
123
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
124
 
 
128
  vae.requires_grad_(False)
129
  vae.eval()
130
  with torch.no_grad():
131
+ train_dataset_group.cache_latents(vae)
132
  vae.to("cpu")
133
  if torch.cuda.is_available():
134
  torch.cuda.empty_cache()
135
  gc.collect()
136
 
137
  # prepare network
138
+ import sys
139
+ sys.path.append(os.path.dirname(__file__))
140
  print("import network module:", args.network_module)
141
  network_module = importlib.import_module(args.network_module)
142
 
 
167
  # 学習に必要なクラスを準備する
168
  print("prepare optimizer, data loader etc.")
169
 
 
 
 
 
170
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
171
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
172
 
173
  # dataloaderを準備する
174
  # DataLoaderのプロセス数:0はメインプロセスになる
175
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
176
  train_dataloader = torch.utils.data.DataLoader(
177
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
178
 
179
  # 学習ステップ数を計算する
180
  if args.max_train_epochs is not None:
181
+ args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
182
+ if is_main_process:
183
+ print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
184
 
185
  # lr schedulerを用意する
186
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
187
+ num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
188
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
 
189
 
190
  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
191
  if args.full_fp16:
 
253
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
254
 
255
  # 学習する
256
+ # TODO: find a way to handle total batch size when there are multiple datasets
257
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
258
+
259
+ if is_main_process:
260
+ print("running training / 学習開始")
261
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
262
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
263
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
264
+ print(f" num epochs / epoch数: {num_train_epochs}")
265
+ print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
266
+ # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
267
+ print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
268
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
269
+
270
+ # TODO refactor metadata creation and move to util
271
  metadata = {
272
  "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
273
  "ss_training_started_at": training_started_at, # unix timestamp
 
275
  "ss_learning_rate": args.learning_rate,
276
  "ss_text_encoder_lr": args.text_encoder_lr,
277
  "ss_unet_lr": args.unet_lr,
278
+ "ss_num_train_images": train_dataset_group.num_train_images,
279
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
280
  "ss_num_batches_per_epoch": len(train_dataloader),
281
  "ss_num_epochs": num_train_epochs,
 
 
282
  "ss_gradient_checkpointing": args.gradient_checkpointing,
283
  "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
284
  "ss_max_train_steps": args.max_train_steps,
 
290
  "ss_mixed_precision": args.mixed_precision,
291
  "ss_full_fp16": bool(args.full_fp16),
292
  "ss_v2": bool(args.v2),
 
293
  "ss_clip_skip": args.clip_skip,
294
  "ss_max_token_length": args.max_token_length,
 
 
 
 
295
  "ss_cache_latents": bool(args.cache_latents),
 
 
 
296
  "ss_seed": args.seed,
297
+ "ss_lowram": args.lowram,
298
  "ss_noise_offset": args.noise_offset,
 
 
 
 
299
  "ss_training_comment": args.training_comment, # will not be updated after training
300
  "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
301
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
302
+ "ss_max_grad_norm": args.max_grad_norm,
303
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
304
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
305
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
306
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
307
+ "ss_prior_loss_weight": args.prior_loss_weight,
308
  }
309
 
310
+ if use_user_config:
311
+ # save metadata of multiple datasets
312
+ # NOTE: pack "ss_datasets" value as json one time
313
+ # or should also pack nested collections as json?
314
+ datasets_metadata = []
315
+ tag_frequency = {} # merge tag frequency for metadata editor
316
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
317
+
318
+ for dataset in train_dataset_group.datasets:
319
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
320
+ dataset_metadata = {
321
+ "is_dreambooth": is_dreambooth_dataset,
322
+ "batch_size_per_device": dataset.batch_size,
323
+ "num_train_images": dataset.num_train_images, # includes repeating
324
+ "num_reg_images": dataset.num_reg_images,
325
+ "resolution": (dataset.width, dataset.height),
326
+ "enable_bucket": bool(dataset.enable_bucket),
327
+ "min_bucket_reso": dataset.min_bucket_reso,
328
+ "max_bucket_reso": dataset.max_bucket_reso,
329
+ "tag_frequency": dataset.tag_frequency,
330
+ "bucket_info": dataset.bucket_info,
331
+ }
332
+
333
+ subsets_metadata = []
334
+ for subset in dataset.subsets:
335
+ subset_metadata = {
336
+ "img_count": subset.img_count,
337
+ "num_repeats": subset.num_repeats,
338
+ "color_aug": bool(subset.color_aug),
339
+ "flip_aug": bool(subset.flip_aug),
340
+ "random_crop": bool(subset.random_crop),
341
+ "shuffle_caption": bool(subset.shuffle_caption),
342
+ "keep_tokens": subset.keep_tokens,
343
+ }
344
+
345
+ image_dir_or_metadata_file = None
346
+ if subset.image_dir:
347
+ image_dir = os.path.basename(subset.image_dir)
348
+ subset_metadata["image_dir"] = image_dir
349
+ image_dir_or_metadata_file = image_dir
350
+
351
+ if is_dreambooth_dataset:
352
+ subset_metadata["class_tokens"] = subset.class_tokens
353
+ subset_metadata["is_reg"] = subset.is_reg
354
+ if subset.is_reg:
355
+ image_dir_or_metadata_file = None # not merging reg dataset
356
+ else:
357
+ metadata_file = os.path.basename(subset.metadata_file)
358
+ subset_metadata["metadata_file"] = metadata_file
359
+ image_dir_or_metadata_file = metadata_file # may overwrite
360
+
361
+ subsets_metadata.append(subset_metadata)
362
+
363
+ # merge dataset dir: not reg subset only
364
+ # TODO update additional-network extension to show detailed dataset config from metadata
365
+ if image_dir_or_metadata_file is not None:
366
+ # datasets may have a certain dir multiple times
367
+ v = image_dir_or_metadata_file
368
+ i = 2
369
+ while v in dataset_dirs_info:
370
+ v = image_dir_or_metadata_file + f" ({i})"
371
+ i += 1
372
+ image_dir_or_metadata_file = v
373
+
374
+ dataset_dirs_info[image_dir_or_metadata_file] = {
375
+ "n_repeats": subset.num_repeats,
376
+ "img_count": subset.img_count
377
+ }
378
+
379
+ dataset_metadata["subsets"] = subsets_metadata
380
+ datasets_metadata.append(dataset_metadata)
381
+
382
+ # merge tag frequency:
383
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
384
+ # if a directory is used by multiple datasets, count it only once
385
+ # the number of repeats is specified separately, so the tag counts in captions do not match how many times they are seen during training
386
+ # therefore summing the counts across multiple datasets here is not very meaningful
387
+ if ds_dir_name in tag_frequency:
388
+ continue
389
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
390
+
391
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
392
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
393
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
394
+ else:
395
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
396
+ assert len(
397
+ train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
398
+
399
+ dataset = train_dataset_group.datasets[0]
400
+
401
+ dataset_dirs_info = {}
402
+ reg_dataset_dirs_info = {}
403
+ if use_dreambooth_method:
404
+ for subset in dataset.subsets:
405
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
406
+ info[os.path.basename(subset.image_dir)] = {
407
+ "n_repeats": subset.num_repeats,
408
+ "img_count": subset.img_count
409
+ }
410
+ else:
411
+ for subset in dataset.subsets:
412
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
413
+ "n_repeats": subset.num_repeats,
414
+ "img_count": subset.img_count
415
+ }
416
+
417
+ metadata.update({
418
+ "ss_batch_size_per_device": args.train_batch_size,
419
+ "ss_total_batch_size": total_batch_size,
420
+ "ss_resolution": args.resolution,
421
+ "ss_color_aug": bool(args.color_aug),
422
+ "ss_flip_aug": bool(args.flip_aug),
423
+ "ss_random_crop": bool(args.random_crop),
424
+ "ss_shuffle_caption": bool(args.shuffle_caption),
425
+ "ss_enable_bucket": bool(dataset.enable_bucket),
426
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
427
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
428
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
429
+ "ss_keep_tokens": args.keep_tokens,
430
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
431
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
432
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
433
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
434
+ })
435
+
436
+ # add extra args
437
+ if args.network_args:
438
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
439
+ # for key, value in net_kwargs.items():
440
+ # metadata["ss_arg_" + key] = value
441
+
442
+ # model name and hash
443
  if args.pretrained_model_name_or_path is not None:
444
  sd_model_name = args.pretrained_model_name_or_path
445
  if os.path.exists(sd_model_name):
 
458
 
459
  metadata = {k: str(v) for k, v in metadata.items()}
460
 
461
+ # make minimum metadata for filtering
462
+ minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
463
+ minimum_metadata = {}
464
+ for key in minimum_keys:
465
+ if key in metadata:
466
+ minimum_metadata[key] = metadata[key]
467
+
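All ss_* values above end up stored as strings, with the nested structures (ss_datasets, ss_tag_frequency, ss_dataset_dirs, ss_network_args) packed as JSON. A minimal sketch of reading them back later, assuming the model was saved as .safetensors and the safetensors package is installed; this helper is illustrative and not part of the commit:

import json
from safetensors import safe_open

def read_ss_metadata(path):
    # safetensors keeps the training metadata in the file header as a str -> str dict
    with safe_open(path, framework="pt") as f:
        meta = f.metadata() or {}
    datasets = json.loads(meta["ss_datasets"]) if "ss_datasets" in meta else None
    tag_frequency = json.loads(meta.get("ss_tag_frequency", "{}"))
    return meta, datasets, tag_frequency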
468
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
469
  global_step = 0
470
 
 
477
  loss_list = []
478
  loss_total = 0.0
479
  for epoch in range(num_train_epochs):
480
+ if is_main_process:
481
+ print(f"epoch {epoch+1}/{num_train_epochs}")
482
+ train_dataset_group.set_current_epoch(epoch + 1)
483
 
484
  metadata["ss_epoch"] = str(epoch+1)
485
 
 
516
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
517
 
518
  # Predict the noise residual
519
+ with accelerator.autocast():
520
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
521
 
522
  if args.v_parameterization:
 
534
  loss = loss.mean()  # already averaged, so no need to divide by batch_size
535
 
536
  accelerator.backward(loss)
537
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
538
  params_to_clip = network.get_trainable_params()
539
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
540
 
541
  optimizer.step()
542
  lr_scheduler.step()
 
547
  progress_bar.update(1)
548
  global_step += 1
549
 
550
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
551
+
552
  current_loss = loss.detach().item()
553
  if epoch == 0:
554
  loss_list.append(current_loss)
 
579
  def save_func():
580
  ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
581
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
582
+ metadata["ss_training_finished_at"] = str(time.time())
583
  print(f"saving checkpoint: {ckpt_file}")
584
+ unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
585
 
586
  def remove_old_func(old_epoch_no):
587
  old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
 
590
  print(f"removing old checkpoint: {old_ckpt_file}")
591
  os.remove(old_ckpt_file)
592
 
593
+ if is_main_process:
594
+ saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
595
+ if saving and args.save_state:
596
+ train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
597
+
598
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
599
 
600
  # end of epoch
601
 
602
  metadata["ss_epoch"] = str(num_train_epochs)
603
+ metadata["ss_training_finished_at"] = str(time.time())
604
 
 
605
  if is_main_process:
606
  network = unwrap_model(network)
607
 
 
620
  ckpt_file = os.path.join(args.output_dir, ckpt_name)
621
 
622
  print(f"save trained model to {ckpt_file}")
623
+ network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
624
  print("model saved.")
625
 
626
 
 
630
  train_util.add_sd_models_arguments(parser)
631
  train_util.add_dataset_arguments(parser, True, True, True)
632
  train_util.add_training_arguments(parser, True)
633
+ train_util.add_optimizer_arguments(parser)
634
+ config_util.add_config_arguments(parser)
635
 
636
  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
637
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
 
639
 
640
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
641
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
 
 
 
642
 
643
  parser.add_argument("--network_weights", type=str, default=None,
644
  help="pretrained weights for network / 学習するネットワークの初期重み")
train_textual_inversion.py CHANGED
@@ -11,7 +11,11 @@ import diffusers
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
- from library.train_util import DreamBoothDataset, FineTuningDataset
 
 
 
 
15
 
16
  imagenet_templates_small = [
17
  "a photo of a {}",
@@ -79,7 +83,6 @@ def train(args):
79
  train_util.prepare_dataset_args(args, True)
80
 
81
  cache_latents = args.cache_latents
82
- use_dreambooth_method = args.in_json is None
83
 
84
  if args.seed is not None:
85
  set_seed(args.seed)
@@ -139,21 +142,35 @@ def train(args):
139
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
140
 
141
  # prepare the dataset
142
- if use_dreambooth_method:
143
- print("Use DreamBooth method.")
144
- train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
145
- tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
146
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
147
- args.bucket_reso_steps, args.bucket_no_upscale,
148
- args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
149
  else:
150
- print("Train with captions.")
151
- train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
152
- tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
153
- args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
154
- args.bucket_reso_steps, args.bucket_no_upscale,
155
- args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
156
- args.dataset_repeats, args.debug_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # make captions: a very crude implementation that rewrites each caption to the string "tokenstring tokenstring1 tokenstring2 ... tokenstringN"
159
  if use_template:
@@ -163,20 +180,30 @@ def train(args):
163
  captions = []
164
  for tmpl in templates:
165
  captions.append(tmpl.format(replace_to))
166
- train_dataset.add_replacement("", captions)
167
- elif args.num_vectors_per_token > 1:
168
- replace_to = " ".join(token_strings)
169
- train_dataset.add_replacement(args.token_string, replace_to)
170
 
171
- train_dataset.make_buckets()
 
 
 
 
 
 
 
 
 
 
172
 
173
  if args.debug_dataset:
174
- train_util.debug_dataset(train_dataset, show_input_ids=True)
175
  return
176
- if len(train_dataset) == 0:
177
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
178
  return
179
 
 
 
 
180
  # add xformers or memory efficient attention to the model
181
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
182
 
@@ -186,7 +213,7 @@ def train(args):
186
  vae.requires_grad_(False)
187
  vae.eval()
188
  with torch.no_grad():
189
- train_dataset.cache_latents(vae)
190
  vae.to("cpu")
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
@@ -198,35 +225,14 @@ def train(args):
198
 
199
  # prepare the classes required for training
200
  print("prepare optimizer, data loader etc.")
201
-
202
- # use 8-bit Adam
203
- if args.use_8bit_adam:
204
- try:
205
- import bitsandbytes as bnb
206
- except ImportError:
207
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
208
- print("use 8-bit Adam optimizer")
209
- optimizer_class = bnb.optim.AdamW8bit
210
- elif args.use_lion_optimizer:
211
- try:
212
- import lion_pytorch
213
- except ImportError:
214
- raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
215
- print("use Lion optimizer")
216
- optimizer_class = lion_pytorch.Lion
217
- else:
218
- optimizer_class = torch.optim.AdamW
219
-
220
  trainable_params = text_encoder.get_input_embeddings().parameters()
221
-
222
- # betas and weight decay seem to be left at the defaults in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
223
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
224
 
225
  # prepare the dataloader
226
  # number of DataLoader processes: 0 means the main process
227
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
228
  train_dataloader = torch.utils.data.DataLoader(
229
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
230
 
231
  # calculate the number of training steps
232
  if args.max_train_epochs is not None:
@@ -234,8 +240,9 @@ def train(args):
234
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
235
 
236
  # prepare the lr scheduler
237
- lr_scheduler = diffusers.optimization.get_scheduler(
238
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
239
 
240
  # accelerator apparently takes care of the rest for us
241
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -283,8 +290,8 @@ def train(args):
283
  # start training
284
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
285
  print("running training / 学習開始")
286
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
287
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
288
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
289
  print(f" num epochs / epoch数: {num_train_epochs}")
290
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -303,12 +310,11 @@ def train(args):
303
 
304
  for epoch in range(num_train_epochs):
305
  print(f"epoch {epoch+1}/{num_train_epochs}")
306
- train_dataset.set_current_epoch(epoch + 1)
307
 
308
  text_encoder.train()
309
 
310
  loss_total = 0
311
- bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
312
  for step, batch in enumerate(train_dataloader):
313
  with accelerator.accumulate(text_encoder):
314
  with torch.no_grad():
@@ -357,9 +363,9 @@ def train(args):
357
  loss = loss.mean()  # already averaged, so no need to divide by batch_size
358
 
359
  accelerator.backward(loss)
360
- if accelerator.sync_gradients:
361
  params_to_clip = text_encoder.get_input_embeddings().parameters()
362
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
363
 
364
  optimizer.step()
365
  lr_scheduler.step()
@@ -374,9 +380,14 @@ def train(args):
374
  progress_bar.update(1)
375
  global_step += 1
376
 
 
 
 
377
  current_loss = loss.detach().item()
378
  if args.logging_dir is not None:
379
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
 
 
380
  accelerator.log(logs, step=global_step)
381
 
382
  loss_total += current_loss
@@ -394,8 +405,6 @@ def train(args):
394
  accelerator.wait_for_everyone()
395
 
396
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
397
- # d = updated_embs - bef_epo_embs
398
- # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
399
 
400
  if args.save_every_n_epochs is not None:
401
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -417,6 +426,9 @@ def train(args):
417
  if saving and args.save_state:
418
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
419
 
 
 
 
420
  # end of epoch
421
 
422
  is_main_process = accelerator.is_main_process
@@ -491,6 +503,8 @@ if __name__ == '__main__':
491
  train_util.add_sd_models_arguments(parser)
492
  train_util.add_dataset_arguments(parser, True, True, False)
493
  train_util.add_training_arguments(parser, True)
 
 
494
 
495
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
496
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
 
11
  from diffusers import DDPMScheduler
12
 
13
  import library.train_util as train_util
14
+ import library.config_util as config_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
 
20
  imagenet_templates_small = [
21
  "a photo of a {}",
 
83
  train_util.prepare_dataset_args(args, True)
84
 
85
  cache_latents = args.cache_latents
 
86
 
87
  if args.seed is not None:
88
  set_seed(args.seed)
 
142
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
143
 
144
  # prepare the dataset
145
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
146
+ if args.dataset_config is not None:
147
+ print(f"Load dataset config from {args.dataset_config}")
148
+ user_config = config_util.load_user_config(args.dataset_config)
149
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
150
+ if any(getattr(args, attr) is not None for attr in ignored):
151
+ print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
152
  else:
153
+ use_dreambooth_method = args.in_json is None
154
+ if use_dreambooth_method:
155
+ print("Use DreamBooth method.")
156
+ user_config = {
157
+ "datasets": [{
158
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
159
+ }]
160
+ }
161
+ else:
162
+ print("Train with captions.")
163
+ user_config = {
164
+ "datasets": [{
165
+ "subsets": [{
166
+ "image_dir": args.train_data_dir,
167
+ "metadata_file": args.in_json,
168
+ }]
169
+ }]
170
+ }
171
+
172
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
173
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
174
 
175
  # make captions: a very crude implementation that rewrites each caption to the string "tokenstring tokenstring1 tokenstring2 ... tokenstringN"
176
  if use_template:
 
180
  captions = []
181
  for tmpl in templates:
182
  captions.append(tmpl.format(replace_to))
183
+ train_dataset_group.add_replacement("", captions)
 
 
 
184
 
185
+ if args.num_vectors_per_token > 1:
186
+ prompt_replacement = (args.token_string, replace_to)
187
+ else:
188
+ prompt_replacement = None
189
+ else:
190
+ if args.num_vectors_per_token > 1:
191
+ replace_to = " ".join(token_strings)
192
+ train_dataset_group.add_replacement(args.token_string, replace_to)
193
+ prompt_replacement = (args.token_string, replace_to)
194
+ else:
195
+ prompt_replacement = None
196
 
197
  if args.debug_dataset:
198
+ train_util.debug_dataset(train_dataset_group, show_input_ids=True)
199
  return
200
+ if len(train_dataset_group) == 0:
201
  print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
202
  return
203
 
204
+ if cache_latents:
205
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
206
+
207
  # add xformers or memory efficient attention to the model
208
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
209
 
 
213
  vae.requires_grad_(False)
214
  vae.eval()
215
  with torch.no_grad():
216
+ train_dataset_group.cache_latents(vae)
217
  vae.to("cpu")
218
  if torch.cuda.is_available():
219
  torch.cuda.empty_cache()
 
225
 
226
  # prepare the classes required for training
227
  print("prepare optimizer, data loader etc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  trainable_params = text_encoder.get_input_embeddings().parameters()
229
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 
230
 
231
  # prepare the dataloader
232
  # number of DataLoader processes: 0 means the main process
233
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
234
  train_dataloader = torch.utils.data.DataLoader(
235
+ train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
236
 
237
  # calculate the number of training steps
238
  if args.max_train_epochs is not None:
 
240
  print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
241
 
242
  # prepare the lr scheduler
243
+ lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
244
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
245
+ num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
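Note the difference from the train_network-style call earlier in this diff: there num_training_steps is also multiplied by accelerator.num_processes, here only by gradient_accumulation_steps. A worked example with illustrative numbers (the extra factor presumably compensates for accelerate stepping a prepared scheduler once per process):

import math

# illustrative numbers only, not taken from the commit
num_batches_per_epoch, num_processes, epochs, grad_accum = 400, 2, 2, 1

# per-process optimizer steps, assuming max_train_steps = epochs * ceil(batches / num_processes)
max_train_steps = epochs * math.ceil(num_batches_per_epoch / num_processes)          # 400

scheduler_steps_with_process_factor = max_train_steps * num_processes * grad_accum   # 800 (train_network-style)
scheduler_steps_without_process_factor = max_train_steps * grad_accum                # 400 (this script)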
246
 
247
  # accelerator apparently takes care of the rest for us
248
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 
290
  # start training
291
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
292
  print("running training / 学習開始")
293
+ print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
294
+ print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
295
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
296
  print(f" num epochs / epoch数: {num_train_epochs}")
297
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 
310
 
311
  for epoch in range(num_train_epochs):
312
  print(f"epoch {epoch+1}/{num_train_epochs}")
313
+ train_dataset_group.set_current_epoch(epoch + 1)
314
 
315
  text_encoder.train()
316
 
317
  loss_total = 0
 
318
  for step, batch in enumerate(train_dataloader):
319
  with accelerator.accumulate(text_encoder):
320
  with torch.no_grad():
 
363
  loss = loss.mean()  # already averaged, so no need to divide by batch_size
364
 
365
  accelerator.backward(loss)
366
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
367
  params_to_clip = text_encoder.get_input_embeddings().parameters()
368
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
369
 
370
  optimizer.step()
371
  lr_scheduler.step()
 
380
  progress_bar.update(1)
381
  global_step += 1
382
 
383
+ train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
384
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
385
+
386
  current_loss = loss.detach().item()
387
  if args.logging_dir is not None:
388
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
389
+ if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
390
+ logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
391
  accelerator.log(logs, step=global_step)
392
 
393
  loss_total += current_loss
 
405
  accelerator.wait_for_everyone()
406
 
407
  updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
 
 
408
 
409
  if args.save_every_n_epochs is not None:
410
  model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
 
426
  if saving and args.save_state:
427
  train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
428
 
429
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
430
+ vae, tokenizer, text_encoder, unet, prompt_replacement)
431
+
432
  # end of epoch
433
 
434
  is_main_process = accelerator.is_main_process
 
503
  train_util.add_sd_models_arguments(parser)
504
  train_util.add_dataset_arguments(parser, True, True, False)
505
  train_util.add_training_arguments(parser, True)
506
+ train_util.add_optimizer_arguments(parser)
507
+ config_util.add_config_arguments(parser)
508
 
509
  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
510
  help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")