Commit 0d34cbd · committed by abc · 1 parent: b9a80b5
Upload 44 files
Files changed:
- .gitattributes +1 -0
- fine_tune.py +50 -45
- gen_img_diffusers.py +234 -55
- library/model_util.py +5 -1
- library/train_util.py +853 -229
- networks/check_lora_weights.py +1 -1
- networks/extract_lora_from_models.py +44 -25
- networks/lora.py +191 -30
- networks/merge_lora.py +11 -5
- networks/resize_lora.py +187 -50
- networks/svd_merge_lora.py +40 -18
- requirements.txt +2 -1
- train_db.py +47 -45
- train_network.py +248 -175
- train_textual_inversion.py +72 -58
.gitattributes
ADDED
@@ -0,0 +1 @@
+bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
fine_tune.py
CHANGED
@@ -13,7 +13,11 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
-
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)

def collate_fn(examples):
  return examples[0]

@@ -30,25 +34,36 @@ def train(args):

  tokenizer = train_util.load_tokenizer(args)

-
-
-
-
-
-
-
-
-
-
-
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "in_json"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": [{
+          "image_dir": args.train_data_dir,
+          "metadata_file": args.in_json,
+        }]
+      }]
+    }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

  if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset_group)
    return
-  if len(
+  if len(train_dataset_group) == 0:
    print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
    return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
  # acceleratorを準備する
  print("prepare accelerator")
  accelerator, unwrap_model = train_util.prepare_accelerator(args)

@@ -109,7 +124,7 @@ def train(args):
  vae.requires_grad_(False)
  vae.eval()
  with torch.no_grad():
-
+    train_dataset_group.cache_latents(vae)
  vae.to("cpu")
  if torch.cuda.is_available():
    torch.cuda.empty_cache()

@@ -149,33 +164,13 @@ def train(args):

  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
-  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
-
+      train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:

@@ -183,8 +178,9 @@ def train(args):
    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
-  lr_scheduler =
-
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
  if args.full_fp16:

@@ -218,7 +214,7 @@ def train(args):
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
-  print(f" num examples / サンプル数: {
+  print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f" num epochs / epoch数: {num_train_epochs}")
  print(f" batch size per device / バッチサイズ: {args.train_batch_size}")

@@ -237,7 +233,7 @@ def train(args):

  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
-
+    train_dataset_group.set_current_epoch(epoch + 1)

    for m in training_models:
      m.train()

@@ -286,11 +282,11 @@ def train(args):
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

        accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
          params_to_clip = []
          for m in training_models:
            params_to_clip.extend(m.parameters())
-          accelerator.clip_grad_norm_(params_to_clip,
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()

@@ -301,11 +297,16 @@ def train(args):
        progress_bar.update(1)
        global_step += 1

+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
      current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
      if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
        accelerator.log(logs, step=global_step)

+      # TODO moving averageにする
      loss_total += current_loss
      avr_loss = loss_total / (step+1)
      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}

@@ -315,7 +316,7 @@ def train(args):
        break

    if args.logging_dir is not None:
-      logs = {"
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
      accelerator.log(logs, step=epoch+1)

    accelerator.wait_for_everyone()

@@ -325,6 +326,8 @@ def train(args):
      train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                            save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)

+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
  is_main_process = accelerator.is_main_process
  if is_main_process:
    unet = unwrap_model(unet)

@@ -351,6 +354,8 @@ if __name__ == '__main__':
  train_util.add_dataset_arguments(parser, False, True, True)
  train_util.add_training_arguments(parser, False)
  train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

  parser.add_argument("--diffusers_xformers", action='store_true',
                      help='use xformers by diffusers / Diffusersでxformersを使用する')
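The fine_tune.py change above replaces the script's old in-place dataset construction with the config_util blueprint flow: a user_config is either loaded from --dataset_config or synthesized from --train_data_dir / --in_json, then BlueprintGenerator and generate_dataset_group_by_blueprint build the dataset group from it, and the two CLI options are ignored whenever a config file is supplied. A minimal Python sketch of that fallback structure; the paths are placeholders of my own, only the image_dir and metadata_file keys come from the diff, and any further keys a real config file may accept are not shown here:

# Sketch: the dict fine_tune.py now builds when no --dataset_config is given.
# The two paths below are placeholders standing in for --train_data_dir and --in_json.
user_config = {
  "datasets": [{
    "subsets": [{
      "image_dir": "/path/to/train_images",      # was args.train_data_dir
      "metadata_file": "/path/to/meta_lat.json",  # was args.in_json
    }]
  }]
}

A file passed via --dataset_config would need to describe the same datasets → subsets nesting for BlueprintGenerator to consume it.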
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
"""

import json
-from typing import List, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect

@@ -60,7 +60,6 @@ import math
import os
import random
import re
-from typing import Any, Callable, List, Optional, Union

import diffusers
import numpy as np

@@ -81,6 +80,9 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo

import library.model_util as model_util
+import library.train_util as train_util
+import tools.original_control_net as original_control_net
+from tools.original_control_net import ControlNetInfo

# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"

@@ -487,6 +489,9 @@ class PipelineLike():
    self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
    self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

+    # ControlNet
+    self.control_nets: List[ControlNetInfo] = []
+
  # Textual Inversion
  def add_token_replacement(self, target_token_id, rep_token_ids):
    self.token_replacements[target_token_id] = rep_token_ids

@@ -500,7 +505,11 @@ class PipelineLike():
      new_tokens.append(token)
    return new_tokens

+  def set_control_nets(self, ctrl_nets):
+    self.control_nets = ctrl_nets
+
  # region xformersとか使う部分:独自に書き換えるので関係なし
+
  def enable_xformers_memory_efficient_attention(self):
    r"""
    Enable memory efficient attention as implemented in xformers.

@@ -581,6 +590,8 @@ class PipelineLike():
      latents: Optional[torch.FloatTensor] = None,
      max_embeddings_multiples: Optional[int] = 3,
      output_type: Optional[str] = "pil",
+      vae_batch_size: float = None,
+      return_latents: bool = False,
      # return_dict: bool = True,
      callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
      is_cancelled_callback: Optional[Callable[[], bool]] = None,

@@ -672,6 +683,9 @@ class PipelineLike():
    else:
      raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+    vae_batch_size = batch_size if vae_batch_size is None else (
+        int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
+
    if strength < 0 or strength > 1:
      raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

@@ -752,7 +766,7 @@ class PipelineLike():
      text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
      text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # prompt複数件でもOK

-    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
      if isinstance(clip_guide_images, PIL.Image.Image):
        clip_guide_images = [clip_guide_images]

@@ -765,7 +779,7 @@ class PipelineLike():
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
        if len(image_embeddings_clip) == 1:
          image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-
+      elif self.vgg16_guidance_scale > 0:
        size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # とりあえず1/4に(小さいか?)
        clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
        clip_guide_images = torch.cat(clip_guide_images, dim=0)

@@ -774,6 +788,10 @@ class PipelineLike():
        image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
        if len(image_embeddings_vgg16) == 1:
          image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+      else:
+        # ControlNetのhintにguide imageを流用する
+        # 前処理はControlNet側で行う
+        pass

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps, self.device)

@@ -781,7 +799,6 @@ class PipelineLike():
    latents_dtype = text_embeddings.dtype
    init_latents_orig = None
    mask = None
-    noise = None

    if init_image is None:
      # get the initial random noise unless the user supplied it

@@ -813,6 +830,8 @@ class PipelineLike():
      if isinstance(init_image[0], PIL.Image.Image):
        init_image = [preprocess_image(im) for im in init_image]
        init_image = torch.cat(init_image)
+      if isinstance(init_image, list):
+        init_image = torch.stack(init_image)

      # mask image to tensor
      if mask_image is not None:

@@ -823,9 +842,24 @@ class PipelineLike():

      # encode the init image into latents and scale the latents
      init_image = init_image.to(device=self.device, dtype=latents_dtype)
-
-
-
+      if init_image.size()[2:] == (height // 8, width // 8):
+        init_latents = init_image
+      else:
+        if vae_batch_size >= batch_size:
+          init_latent_dist = self.vae.encode(init_image).latent_dist
+          init_latents = init_latent_dist.sample(generator=generator)
+        else:
+          if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+          init_latents = []
+          for i in tqdm(range(0, batch_size, vae_batch_size)):
+            init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
+                                               if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
+            init_latents.append(init_latent_dist.sample(generator=generator))
+          init_latents = torch.cat(init_latents)
+
+        init_latents = 0.18215 * init_latents
+
      if len(init_latents) == 1:
        init_latents = init_latents.repeat((batch_size, 1, 1, 1))
      init_latents_orig = init_latents

@@ -864,12 +898,21 @@ class PipelineLike():
      extra_step_kwargs["eta"] = eta

    num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
+
+    if self.control_nets:
+      guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+
    for i, t in enumerate(tqdm(timesteps)):
      # expand the latents if we are doing classifier free guidance
      latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
      latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
      # predict the noise residual
-
+      if self.control_nets:
+        noise_pred = original_control_net.call_unet_and_control_net(
+            i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+      else:
+        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

      # perform guidance
      if do_classifier_free_guidance:

@@ -911,8 +954,19 @@ class PipelineLike():
      if is_cancelled_callback is not None and is_cancelled_callback():
        return None

+    if return_latents:
+      return (latents, False)
+
    latents = 1 / 0.18215 * latents
-
+    if vae_batch_size >= batch_size:
+      image = self.vae.decode(latents).sample
+    else:
+      if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+      images = []
+      for i in tqdm(range(0, batch_size, vae_batch_size)):
+        images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
+      image = torch.cat(images)

    image = (image / 2 + 0.5).clamp(0, 1)

@@ -1595,10 +1649,11 @@ def get_unweighted_text_embeddings(
      if pad == eos:  # v1
        text_input_chunk[:, -1] = text_input[0, -1]
      else:  # v2
-
-        text_input_chunk[
-
-        text_input_chunk[
+        for j in range(len(text_input_chunk)):
+          if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # 最後に普通の文字がある
+            text_input_chunk[j, -1] = eos
+          if text_input_chunk[j, 1] == pad:  # BOSだけであとはPAD
+            text_input_chunk[j, 1] = eos

      if clip_skip is None or clip_skip == 1:
        text_embedding = pipe.text_encoder(text_input_chunk)[0]

@@ -1799,7 +1854,7 @@ def preprocess_mask(mask):
  mask = mask.convert("L")
  w, h = mask.size
  w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
+  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # LANCZOS)
  mask = np.array(mask).astype(np.float32) / 255.0
  mask = np.tile(mask, (4, 1, 1))
  mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?

@@ -1817,6 +1872,35 @@ def preprocess_mask(mask):
# return text_encoder


+class BatchDataBase(NamedTuple):
+  # バッチ分割が必要ないデータ
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # バッチ分割が必要なデータ
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  return_latents: bool
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
def main(args):
  if args.fp16:
    dtype = torch.float16

@@ -1881,10 +1965,7 @@ def main(args):
  # tokenizerを読み込む
  print("loading tokenizer")
  if use_stable_diffusion_format:
-
-    tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-  else:
-    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)

  # schedulerを用意する
  sched_init_args = {}

@@ -1995,11 +2076,13 @@ def main(args):
  # networkを組み込む
  if args.network_module:
    networks = []
+    network_default_muls = []
    for i, network_module in enumerate(args.network_module):
      print("import network module:", network_module)
      imported_module = importlib.import_module(network_module)

      network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)

      net_kwargs = {}
      if args.network_args and i < len(args.network_args):

@@ -2014,7 +2097,7 @@ def main(args):
        network_weight = args.network_weights[i]
        print("load network weights from:", network_weight)

-        if model_util.is_safetensors(network_weight):
+        if model_util.is_safetensors(network_weight) and args.network_show_meta:
          from safetensors.torch import safe_open
          with safe_open(network_weight, framework="pt") as f:
            metadata = f.metadata()

@@ -2037,6 +2120,18 @@ def main(args):
  else:
    networks = []

+  # ControlNetの処理
+  control_nets: List[ControlNetInfo] = []
+  if args.control_net_models:
+    for i, model in enumerate(args.control_net_models):
+      prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+      weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+      ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+      ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+      prep = original_control_net.load_preprocess(prep_type)
+      control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+
  if args.opt_channels_last:
    print(f"set optimizing: channels last")
    text_encoder.to(memory_format=torch.channels_last)

@@ -2050,9 +2145,14 @@ def main(args):
    if vgg16_model is not None:
      vgg16_model.to(memory_format=torch.channels_last)

+    for cn in control_nets:
+      cn.unet.to(memory_format=torch.channels_last)
+      cn.net.to(memory_format=torch.channels_last)
+
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                      clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                      vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+  pipe.set_control_nets(control_nets)
  print("pipeline is ready.")

  if args.diffusers_xformers:

@@ -2177,18 +2277,34 @@ def main(args):
    mask_images = l

  # 画像サイズにオプション指定があるときはリサイズする
-  if
-
-
+  if args.W is not None and args.H is not None:
+    if init_images is not None:
+      print(f"resize img2img source images to {args.W}*{args.H}")
+      init_images = resize_images(init_images, (args.W, args.H))
    if mask_images is not None:
      print(f"resize img2img mask images to {args.W}*{args.H}")
      mask_images = resize_images(mask_images, (args.W, args.H))

+  if networks and mask_images:
+    # mask を領域情報として流用する、現在は1枚だけ対応
+    # TODO 複数のnetwork classの混在時の考慮
+    print("use mask as region")
+    # import cv2
+    # for i in range(3):
+    #   cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
+    #   cv2.waitKey()
+    #   cv2.destroyAllWindows()
+    networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+    mask_images = None
+
  prev_image = None  # for VGG16 guided
  if args.guide_image_path is not None:
-    print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-    guide_images =
-
+    print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+    guide_images = []
+    for p in args.guide_image_path:
+      guide_images.extend(load_images(p))
+
+    print(f"loaded {len(guide_images)} guide images for guidance")
    if len(guide_images) == 0:
      print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
      guide_images = None

@@ -2219,33 +2335,46 @@ def main(args):
  iter_seed = random.randint(0, 0x7fffffff)

  # バッチ処理の関数
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
    batch_size = len(batch)

    # highres_fixの処理
    if highres_fix and not highres_1st:
-      # 1st stage
-      print("process 1st
+      # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
+      print("process 1st stage")
      batch_1st = []
-      for
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for _, base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
        width_1st = width_1st - width_1st % 32
        height_1st = height_1st - height_1st % 32
-
+
+        ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                               ext.negative_scale, ext.strength, ext.network_muls)
+        batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
      images_1st = process_batch(batch_1st, True, True)

      # 2nd stageのバッチを作成して以下処理する
-      print("process 2nd
+      print("process 2nd stage")
+      if args.highres_fix_latents_upscaling:
+        org_dtype = images_1st.dtype
+        if images_1st.dtype == torch.bfloat16:
+          images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
+        images_1st = torch.nn.functional.interpolate(
+            images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear')  # , antialias=True)
+        images_1st = images_1st.to(org_dtype)
+
      batch_2nd = []
-      for i, (
-
-
-
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        if not args.highres_fix_latents_upscaling:
+          image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        batch_2nd.append(bd_2nd)
      batch = batch_2nd

-
-
+    # このバッチの情報を取り出す
+    return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+        (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
    noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

    prompts = []

@@ -2278,7 +2407,7 @@ def main(args):
    all_images_are_same = True
    all_masks_are_same = True
    all_guide_images_are_same = True
-    for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
+    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
      prompts.append(prompt)
      negative_prompts.append(negative_prompt)
      seeds.append(seed)

@@ -2295,9 +2424,13 @@ def main(args):
          all_masks_are_same = mask_images[-2] is mask_image

      if guide_image is not None:
-
-
-        all_guide_images_are_same =
+        if type(guide_image) is list:
+          guide_images.extend(guide_image)
+          all_guide_images_are_same = False
+        else:
+          guide_images.append(guide_image)
+          if i > 0 and all_guide_images_are_same:
+            all_guide_images_are_same = guide_images[-2] is guide_image

      # make start code
      torch.manual_seed(seed)

@@ -2320,10 +2453,24 @@ def main(args):
    if guide_images is not None and all_guide_images_are_same:
      guide_images = guide_images[0]

+    # ControlNet使用時はguide imageをリサイズする
+    if control_nets:
+      # TODO resampleのメソッド
+      guide_images = guide_images if type(guide_images) == list else [guide_images]
+      guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+      if len(guide_images) == 1:
+        guide_images = guide_images[0]
+
    # generate
+    if networks:
+      for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+        n.set_multiplier(m)
+
    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
-                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
-
+                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
+                  vae_batch_size=args.vae_batch_size, return_latents=return_latents,
+                  clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
+    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
      return images

    # save image

@@ -2398,6 +2545,7 @@ def main(args):
      strength = 0.8 if args.strength is None else args.strength
      negative_prompt = ""
      clip_prompt = None
+      network_muls = None

      prompt_args = prompt.strip().split(' --')
      prompt = prompt_args[0]

@@ -2461,6 +2609,15 @@ def main(args):
            clip_prompt = m.group(1)
            print(f"clip prompt: {clip_prompt}")
            continue
+
+          m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+          if m:  # network multiplies
+            network_muls = [float(v) for v in m.group(1).split(",")]
+            while len(network_muls) < len(networks):
+              network_muls.append(network_muls[-1])
+            print(f"network mul: {network_muls}")
+            continue
+
        except ValueError as ex:
          print(f"Exception in parsing / 解析エラー: {parg}")
          print(ex)

@@ -2498,7 +2655,12 @@ def main(args):
        mask_image = mask_images[global_step % len(mask_images)]

      if guide_images is not None:
-
+        if control_nets:  # 複数件の場合あり
+          c = len(control_nets)
+          p = global_step % (len(guide_images) // c)
+          guide_image = guide_images[p * c:p * c + c]
+        else:
+          guide_image = guide_images[global_step % len(guide_images)]
      elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
        if prev_image is None:
          print("Generate 1st image without guide image.")

@@ -2506,10 +2668,9 @@ def main(args):
          print("Use previous image as guide image.")
          guide_image = prev_image

-
-
-
-      if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
+      b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                     BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
+      if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
        process_batch(batch_data, highres_fix)
        batch_data.clear()

@@ -2553,6 +2714,8 @@ if __name__ == '__main__':
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
+  parser.add_argument("--vae_batch_size", type=float, default=None,
+                      help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
  parser.add_argument('--sampler', type=str, default='ddim',
                      choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',

@@ -2564,6 +2727,8 @@ if __name__ == '__main__':
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
  parser.add_argument("--vae", type=str, default=None,
                      help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
  #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
  parser.add_argument("--seed", type=int, default=None,

@@ -2578,12 +2743,15 @@ if __name__ == '__main__':
  parser.add_argument("--opt_channels_last", action='store_true',
                      help='set channels last option to model / モデルにchannels lastを指定し最適化する')
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
                      help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                      help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

@@ -2597,15 +2765,26 @@ if __name__ == '__main__':
                      help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                      help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-  parser.add_argument("--guide_image_path", type=str, default=None,
+  parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                      help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
  parser.add_argument("--highres_fix_scale", type=float, default=None,
                      help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
  parser.add_argument("--highres_fix_steps", type=int, default=28,
                      help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
  parser.add_argument("--highres_fix_save_1st", action='store_true',
                      help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
+                      help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
  parser.add_argument("--negative_scale", type=float, default=None,
                      help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")

+  parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                      help='ControlNet models to use / 使用するControlNetのモデル名')
+  parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                      help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+  parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+  parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                      help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
+
  args = parser.parse_args()
  main(args)
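Of the gen_img_diffusers.py changes above, the new --vae_batch_size option has the least obvious semantics: None keeps the VAE batch equal to the image batch, a value of 1 or more is used as an absolute VAE batch size, and a value below 1.0 is treated as a fraction of the image batch size, never dropping below 1. A small self-contained Python sketch of that rule; the helper name and the example numbers are mine, the expression itself mirrors the line added to PipelineLike.__call__:

def resolve_vae_batch_size(vae_batch_size, batch_size):
  # None -> same as the image batch; >= 1 -> absolute size; < 1 -> ratio of the image batch
  if vae_batch_size is None:
    return batch_size
  return int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))

# hypothetical examples
assert resolve_vae_batch_size(None, 8) == 8
assert resolve_vae_batch_size(2, 8) == 2
assert resolve_vae_batch_size(0.25, 8) == 2   # 8 * 0.25
assert resolve_vae_batch_size(0.05, 8) == 1   # clamped so the VAE always gets at least one image

The resolved value drives both the init-image encoding loop and the final latent decoding loop, so a fractional setting caps VAE memory use on both paths.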
library/model_util.py
CHANGED
@@ -4,7 +4,7 @@
 import math
 import os
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from safetensors.torch import load_file, save_file

@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
     info = text_model.load_state_dict(converted_text_encoder_checkpoint)
   else:
     converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+    logging.set_verbosity_error()  # don't show annoying warning
     text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    logging.set_verbosity_warning()
+
     info = text_model.load_state_dict(converted_text_encoder_checkpoint)
     print("loading text encoder:", info)
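The hunk above temporarily lowers transformers' log level so the noisy warning emitted by from_pretrained is hidden while the converted state dict is about to be loaded anyway. A small self-contained sketch of the same pattern; the context-manager wrapper here is illustrative and not part of the repository:

import contextlib
from transformers import CLIPTextModel, logging

@contextlib.contextmanager
def quiet_transformers():
    # drop transformers' verbosity to errors only, then restore the default warning level
    logging.set_verbosity_error()
    try:
        yield
    finally:
        logging.set_verbosity_warning()

with quiet_transformers():
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")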
library/train_util.py
CHANGED
@@ -1,12 +1,21 @@
# common functions for training

import argparse
import json
import shutil
import time
- from typing import
from accelerate import Accelerator
- from torch.autograd.function import Function
import glob
import math
import os
@@ -17,10 +26,16 @@ from io import BytesIO

from tqdm import tqdm
import torch
from torchvision import transforms
from transformers import CLIPTokenizer
import diffusers
- from diffusers import
import albumentations as albu
import numpy as np
from PIL import Image
@@ -195,23 +210,95 @@ class BucketBatchIndex(NamedTuple):
  batch_index: int


class BaseDataset(torch.utils.data.Dataset):
-   def __init__(self, tokenizer, max_token_length
    super().__init__()
-     self.tokenizer
    self.max_token_length = max_token_length
-     self.shuffle_caption = shuffle_caption
-     self.shuffle_keep_tokens = shuffle_keep_tokens
    # width/height is used when enable_bucket==False
    self.width, self.height = (None, None) if resolution is None else resolution
-     self.face_crop_aug_range = face_crop_aug_range
-     self.flip_aug = flip_aug
-     self.color_aug = color_aug
    self.debug_dataset = debug_dataset
-
    self.token_padding_disabled = False
-     self.dataset_dirs_info = {}
-     self.reg_dataset_dirs_info = {}
    self.tag_frequency = {}

    self.enable_bucket = False
@@ -225,49 +312,28 @@ class BaseDataset(torch.utils.data.Dataset):
    self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

    self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
-     self.dropout_rate: float = 0
-     self.dropout_every_n_epochs: int = None
-     self.tag_dropout_rate: float = 0

    # augmentation
-
-     if color_aug:
-       # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
-       self.aug = albu.Compose([
-         albu.OneOf([
-           albu.HueSaturationValue(8, 0, 0, p=.5),
-           albu.RandomGamma((95, 105), p=.5),
-         ], p=.33),
-         albu.HorizontalFlip(p=flip_p)
-       ], p=1.)
-     elif flip_aug:
-       self.aug = albu.Compose([
-         albu.HorizontalFlip(p=flip_p)
-       ], p=1.)
-     else:
-       self.aug = None

    self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])

    self.image_data: Dict[str, ImageInfo] = {}

    self.replacements = {}

  def set_current_epoch(self, epoch):
    self.current_epoch = epoch
-
-   def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
-     # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
-     self.dropout_rate = dropout_rate
-     self.dropout_every_n_epochs = dropout_every_n_epochs
-     self.tag_dropout_rate = tag_dropout_rate

  def set_tag_frequency(self, dir_name, captions):
    frequency_for_dir = self.tag_frequency.get(dir_name, {})
    self.tag_frequency[dir_name] = frequency_for_dir
    for caption in captions:
      for tag in caption.split(","):
-
        tag = tag.lower()
        frequency = frequency_for_dir.get(tag, 0)
        frequency_for_dir[tag] = frequency + 1
@@ -278,42 +344,36 @@ class BaseDataset(torch.utils.data.Dataset):
  def add_replacement(self, str_from, str_to):
    self.replacements[str_from] = str_to

-   def process_caption(self, caption):
    # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
-     is_drop_out =
-     is_drop_out = is_drop_out or

    if is_drop_out:
      caption = ""
    else:
-       if
        def dropout_tags(tokens):
-           if
            return tokens
          l = []
          for token in tokens:
-             if random.random() >=
              l.append(token)
          return l

-
-
-
-
-
-         tokens = dropout_tags(tokens)
-       else:
-         if len(tokens) > self.shuffle_keep_tokens:
-           keep_tokens = tokens[:self.shuffle_keep_tokens]
-           tokens = tokens[self.shuffle_keep_tokens:]

-
-

-

-
-         caption = ", ".join(tokens)

    # textual inversion対応
    for str_from, str_to in self.replacements.items():
@@ -367,8 +427,9 @@ class BaseDataset(torch.utils.data.Dataset):
    input_ids = torch.stack(iids_list)  # 3,77
    return input_ids

-   def register_image(self, info: ImageInfo):
    self.image_data[info.image_key] = info

  def make_buckets(self):
    '''
@@ -467,7 +528,7 @@ class BaseDataset(torch.utils.data.Dataset):
    img = np.array(image, np.uint8)
    return img

-   def trim_and_resize_if_required(self, image, reso, resized_size):
    image_height, image_width = image.shape[0:2]

    if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -477,22 +538,27 @@ class BaseDataset(torch.utils.data.Dataset):
    image_height, image_width = image.shape[0:2]
    if image_width > reso[0]:
      trim_size = image_width - reso[0]
-       p = trim_size // 2 if not
      # print("w", trim_size, p)
      image = image[:, p:p + reso[0]]
    if image_height > reso[1]:
      trim_size = image_height - reso[1]
-       p = trim_size // 2 if not
      # print("h", trim_size, p)
      image = image[p:p + reso[1]]

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
    return image

  def cache_latents(self, vae):
    # TODO ここを高速化したい
    print("caching latents.")
    for info in tqdm(self.image_data.values()):
      if info.latents_npz is not None:
        info.latents = self.load_latents_from_npz(info, False)
        info.latents = torch.FloatTensor(info.latents)
@@ -502,13 +568,13 @@ class BaseDataset(torch.utils.data.Dataset):
        continue

      image = self.load_image(info.absolute_path)
-       image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)

      img_tensor = self.image_transforms(image)
      img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
      info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")

-       if
        image = image[:, ::-1].copy()  # cannot convert to Tensor without copy
        img_tensor = self.image_transforms(image)
        img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -518,11 +584,11 @@ class BaseDataset(torch.utils.data.Dataset):
    image = Image.open(image_path)
    return image.size

-   def load_image_with_face_info(self, image_path: str):
    img = self.load_image(image_path)

    face_cx = face_cy = face_w = face_h = 0
-     if
      tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
      if len(tokens) >= 5:
        face_cx = int(tokens[-4])
@@ -533,7 +599,7 @@ class BaseDataset(torch.utils.data.Dataset):
    return img, face_cx, face_cy, face_w, face_h

  # いい感じに切り出す
-   def crop_target(self, image, face_cx, face_cy, face_w, face_h):
    height, width = image.shape[0:2]
    if height == self.height and width == self.width:
      return image
@@ -541,8 +607,8 @@ class BaseDataset(torch.utils.data.Dataset):
    # 画像サイズはsizeより大きいのでリサイズする
    face_size = max(face_w, face_h)
    min_scale = max(self.height / height, self.width / width)  # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
-     min_scale = min(1.0, max(min_scale, self.size / (face_size *
-     max_scale = min(1.0, max(min_scale, self.size / (face_size *
    if min_scale >= max_scale:  # range指定がmin==max
      scale = min_scale
    else:
@@ -560,13 +626,13 @@ class BaseDataset(torch.utils.data.Dataset):
    for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
      p1 = face_p - target_size // 2  # 顔を中心に持ってくるための切り出し位置

-       if
        # 背景も含めるために顔を中心に置く確率を高めつつずらす
        range = max(length - face_p, face_p)  # 画像の端から顔中心までの距離の長いほう
        p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # -range ~ +range までのいい感じの乱数
      else:
        # range指定があるときのみ、すこしだけランダムに(わりと適当)
-         if
          if face_size > self.size // 10 and face_size >= 40:
            p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

@@ -589,9 +655,6 @@ class BaseDataset(torch.utils.data.Dataset):
    return self._length

  def __getitem__(self, index):
-     if index == 0:
-       self.shuffle_buckets()
-
    bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
    bucket_batch_size = self.buckets_indices[index].bucket_batch_size
    image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -604,28 +667,29 @@ class BaseDataset(torch.utils.data.Dataset):

    for image_key in bucket[image_index:image_index + bucket_batch_size]:
      image_info = self.image_data[image_key]
      loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

      # image/latentsを処理する
      if image_info.latents is not None:
-         latents = image_info.latents if not
        image = None
      elif image_info.latents_npz is not None:
-         latents = self.load_latents_from_npz(image_info,
        latents = torch.FloatTensor(latents)
        image = None
      else:
        # 画像を読み込み、必要ならcropする
-         img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
        im_h, im_w = img.shape[0:2]

        if self.enable_bucket:
-           img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
        else:
          if face_cx > 0:  # 顔位置情報あり
-             img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
          elif im_h > self.height or im_w > self.width:
-             assert
            if im_h > self.height:
              p = random.randint(0, im_h - self.height)
              img = img[p:p + self.height]
@@ -637,8 +701,9 @@ class BaseDataset(torch.utils.data.Dataset):
        assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

        # augmentation
-
-

        latents = None
        image = self.image_transforms(img)  # -1.0~1.0のtorch.Tensorになる
@@ -646,7 +711,7 @@ class BaseDataset(torch.utils.data.Dataset):
      images.append(image)
      latents_list.append(latents)

-       caption = self.process_caption(image_info.caption)
      captions.append(caption)
      if not self.token_padding_disabled:  # this option might be omitted in future
        input_ids_list.append(self.get_input_ids(caption))
@@ -677,9 +742,8 @@ class BaseDataset(torch.utils.data.Dataset):


class DreamBoothDataset(BaseDataset):
-   def __init__(self,
-     super().__init__(tokenizer, max_token_length,
-                      resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)

    assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

@@ -702,7 +766,7 @@ class DreamBoothDataset(BaseDataset):
    self.bucket_reso_steps = None  # この情報は使われない
    self.bucket_no_upscale = False

-     def read_caption(img_path):
      # captionの候補ファイル名を作る
      base_name = os.path.splitext(img_path)[0]
      base_name_face_det = base_name
@@ -725,153 +789,181 @@ class DreamBoothDataset(BaseDataset):
        break
      return caption

-     def load_dreambooth_dir(
-       if not os.path.isdir(
-
-         return

-
-
-         n_repeats = int(tokens[0])
-       except ValueError as e:
-         print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
-         return 0, [], []
-
-       caption_by_folder = '_'.join(tokens[1:])
-       img_paths = glob_images(dir, "*")
-       print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")

      # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
      captions = []
      for img_path in img_paths:
-         cap_for_img = read_caption(img_path)
-

-       self.set_tag_frequency(os.path.basename(

-       return

-     print("prepare
-     train_dirs = os.listdir(train_data_dir)
    num_train_images = 0
-
-
-

    for img_path, caption in zip(img_paths, captions):
-       info = ImageInfo(img_path,
-

-

    print(f"{num_train_images} train images with repeating.")
    self.num_train_images = num_train_images

-
-
-
-     print("prepare reg images.")
-     reg_infos: List[ImageInfo] = []

-
-
-
-

-
-         info = ImageInfo(img_path, n_repeats, caption, True, img_path)
-         reg_infos.append(info)

-       self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

-
-
-

-
-
    else:
-
-       n = 0
-       first_loop = True
-       while n < num_train_images:
-         for info in reg_infos:
-           if first_loop:
-             self.register_image(info)
-             n += info.num_repeats
-           else:
-             info.num_repeats += 1
-             n += 1
-           if n >= num_train_images:
-             break
-         first_loop = False

-


-
-
-
-
-
-
-
-       print(f"loading existing metadata: {json_file_name}")
-       with open(json_file_name, "rt", encoding='utf-8') as f:
-         metadata = json.load(f)
-     else:
-       raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")

-
-
-     self.batch_size = batch_size

-
-
-
-
-
-
-
-
-         assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
-         abs_path = abs_path[0]
-
-       caption = img_md.get('caption')
-       tags = img_md.get('tags')
-       if caption is None:
-         caption = tags
-       elif tags is not None and len(tags) > 0:
-         caption = caption + ', ' + tags
-         tags_list.append(tags)
-       assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
-
-       image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
-       image_info.image_size = img_md.get('train_resolution')
-
-       if not self.color_aug and not self.random_crop:
-         # if npz exists, use them
-         image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
-
-       self.register_image(image_info)
-     self.num_train_images = len(metadata) * dataset_repeats
-     self.num_reg_images = 0

-
-
-

    # check existence of all npz files
-     use_npz_latents = not (
    if use_npz_latents:
      npz_any = False
      npz_all = True
      for image_info in self.image_data.values():
        has_npz = image_info.latents_npz is not None
        npz_any = npz_any or has_npz

-         if
          has_npz = has_npz and image_info.latents_npz_flipped is not None
        npz_all = npz_all and has_npz

      if npz_any and not npz_all:
@@ -883,7 +975,7 @@ class FineTuningDataset(BaseDataset):
    elif not npz_all:
      use_npz_latents = False
      print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
-       if
        print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
    # else:
    #   print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -929,7 +1021,7 @@ class FineTuningDataset(BaseDataset):
    for image_info in self.image_data.values():
      image_info.latents_npz = image_info.latents_npz_flipped = None

-   def image_key_to_npz_file(self, image_key):
    base_name = os.path.splitext(image_key)[0]
    npz_file_norm = base_name + '.npz'

@@ -941,8 +1033,8 @@ class FineTuningDataset(BaseDataset):
    return npz_file_norm, npz_file_flip

    # image_key is relative path
-     npz_file_norm = os.path.join(
-     npz_file_flip = os.path.join(

    if not os.path.exists(npz_file_norm):
      npz_file_norm = None
@@ -953,13 +1045,60 @@ class FineTuningDataset(BaseDataset):
    return npz_file_norm, npz_file_flip


def debug_dataset(train_dataset, show_input_ids=False):
  print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
  print("Escape for exit. / Escキーで中断、終了します")

  train_dataset.set_current_epoch(1)
  k = 0
-
    if example['latents'] is not None:
      print(f"sample has latents from npz file: {example['latents'].size()}")
    for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
@@ -1364,6 +1503,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
                      help='enable v-parameterization training / v-parameterization学習を有効にする')
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                      help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1387,10 +1555,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
  parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
  parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                      help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
-   parser.add_argument("--use_8bit_adam", action="store_true",
-                       help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
-   parser.add_argument("--use_lion_optimizer", action="store_true",
-                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
  parser.add_argument("--mem_eff_attn", action="store_true",
                      help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
  parser.add_argument("--xformers", action="store_true",
@@ -1398,7 +1562,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
  parser.add_argument("--vae", type=str, default=None,
                      help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")

-   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
  parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
  parser.add_argument("--max_train_epochs", type=int, default=None,
                      help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1419,15 +1582,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
  parser.add_argument("--logging_dir", type=str, default=None,
                      help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
  parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
-   parser.add_argument("--lr_scheduler", type=str, default="constant",
-                       help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
-   parser.add_argument("--lr_warmup_steps", type=int, default=0,
-                       help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
  parser.add_argument("--noise_offset", type=float, default=None,
                      help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
  parser.add_argument("--lowram", action="store_true",
                      help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")

  if support_dreambooth:
    # DreamBooth training
    parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -1449,8 +1620,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
  parser.add_argument("--caption_extention", type=str, default=None,
                      help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
-   parser.add_argument("--keep_tokens", type=int, default=
-                       help="keep heading N tokens when shuffling caption tokens / caption
  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
  parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -1475,11 +1646,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
  if support_caption_dropout:
    # Textual Inversion はcaptionのdropoutをsupportしない
    # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
-     parser.add_argument("--caption_dropout_rate", type=float, default=0,
                        help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
-     parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
                        help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
-     parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
                        help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")

  if support_dreambooth:
@@ -1504,16 +1675,256 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
# region utils


def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
  # backward compatibility
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention
    args.caption_extention = None

-   if args.cache_latents:
-     assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
-     assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
-
  # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
  if args.resolution is not None:
    args.resolution = tuple([int(r) for r in args.resolution.split(',')])
@@ -1536,12 +1947,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):

def load_tokenizer(args: argparse.Namespace):
  print("prepare tokenizer")
-   if args.v2
-
-
-
-
    print(f"update token length: {args.max_token_length}")
  return tokenizer


@@ -1592,13 +2019,19 @@ def prepare_dtype(args: argparse.Namespace):


def load_target_model(args: argparse.Namespace, weight_dtype):
-
  if load_stable_diffusion_format:
    print("load StableDiffusion checkpoint")
-     text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
  else:
    print("load Diffusers pretrained models")
-
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    unet = pipe.unet
@@ -1767,6 +2200,197 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
  model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
  accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))

# endregion

# region 前処理用
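The hunks above drop the optimizer and learning-rate-scheduler arguments (--use_8bit_adam, --use_lion_optimizer, --lr_scheduler, --lr_warmup_steps) from the argument helpers, and the updated train_util.py shown below imports SchedulerType and TYPE_TO_SCHEDULER_FUNCTION from diffusers.optimization, so optimizer and scheduler selection moves into the library. A minimal sketch of how that diffusers lookup table can be used (a generic illustration under those assumptions, not the repository's exact helper):

import torch
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

params = [torch.nn.Parameter(torch.zeros(1))]          # stand-in trainable parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)

name = SchedulerType("cosine")                          # e.g. taken from a --lr_scheduler value
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]        # resolves to get_cosine_schedule_with_warmup
lr_scheduler = schedule_func(optimizer, num_warmup_steps=100, num_training_steps=1600)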
|
|
| 1 |
# common functions for training
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
+
import importlib
|
| 5 |
import json
|
| 6 |
+
import re
|
| 7 |
import shutil
|
| 8 |
import time
|
| 9 |
+
from typing import (
|
| 10 |
+
Dict,
|
| 11 |
+
List,
|
| 12 |
+
NamedTuple,
|
| 13 |
+
Optional,
|
| 14 |
+
Sequence,
|
| 15 |
+
Tuple,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
from accelerate import Accelerator
|
|
|
|
| 19 |
import glob
|
| 20 |
import math
|
| 21 |
import os
|
|
|
|
| 26 |
|
| 27 |
from tqdm import tqdm
|
| 28 |
import torch
|
| 29 |
+
from torch.optim import Optimizer
|
| 30 |
from torchvision import transforms
|
| 31 |
from transformers import CLIPTokenizer
|
| 32 |
+
import transformers
|
| 33 |
import diffusers
|
| 34 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
| 35 |
+
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
| 36 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
| 37 |
+
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
| 38 |
+
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
| 39 |
import albumentations as albu
|
| 40 |
import numpy as np
|
| 41 |
from PIL import Image
|
|
|
|
| 210 |
batch_index: int
|
| 211 |
|
| 212 |
|
| 213 |
+
class AugHelper:
|
| 214 |
+
def __init__(self):
|
| 215 |
+
# prepare all possible augmentators
|
| 216 |
+
color_aug_method = albu.OneOf([
|
| 217 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
| 218 |
+
albu.RandomGamma((95, 105), p=.5),
|
| 219 |
+
], p=.33)
|
| 220 |
+
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
| 221 |
+
|
| 222 |
+
# key: (use_color_aug, use_flip_aug)
|
| 223 |
+
self.augmentors = {
|
| 224 |
+
(True, True): albu.Compose([
|
| 225 |
+
color_aug_method,
|
| 226 |
+
flip_aug_method,
|
| 227 |
+
], p=1.),
|
| 228 |
+
(True, False): albu.Compose([
|
| 229 |
+
color_aug_method,
|
| 230 |
+
], p=1.),
|
| 231 |
+
(False, True): albu.Compose([
|
| 232 |
+
flip_aug_method,
|
| 233 |
+
], p=1.),
|
| 234 |
+
(False, False): None
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
| 238 |
+
return self.augmentors[(use_color_aug, use_flip_aug)]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class BaseSubset:
|
| 242 |
+
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
| 243 |
+
self.image_dir = image_dir
|
| 244 |
+
self.num_repeats = num_repeats
|
| 245 |
+
self.shuffle_caption = shuffle_caption
|
| 246 |
+
self.keep_tokens = keep_tokens
|
| 247 |
+
self.color_aug = color_aug
|
| 248 |
+
self.flip_aug = flip_aug
|
| 249 |
+
self.face_crop_aug_range = face_crop_aug_range
|
| 250 |
+
self.random_crop = random_crop
|
| 251 |
+
self.caption_dropout_rate = caption_dropout_rate
|
| 252 |
+
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
| 253 |
+
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
| 254 |
+
|
| 255 |
+
self.img_count = 0
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class DreamBoothSubset(BaseSubset):
|
| 259 |
+
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
| 260 |
+
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
| 261 |
+
|
| 262 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
| 263 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
| 264 |
+
|
| 265 |
+
self.is_reg = is_reg
|
| 266 |
+
self.class_tokens = class_tokens
|
| 267 |
+
self.caption_extension = caption_extension
|
| 268 |
+
|
| 269 |
+
def __eq__(self, other) -> bool:
|
| 270 |
+
if not isinstance(other, DreamBoothSubset):
|
| 271 |
+
return NotImplemented
|
| 272 |
+
return self.image_dir == other.image_dir
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class FineTuningSubset(BaseSubset):
|
| 276 |
+
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
| 277 |
+
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
| 278 |
+
|
| 279 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
| 280 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
| 281 |
+
|
| 282 |
+
self.metadata_file = metadata_file
|
| 283 |
+
|
| 284 |
+
def __eq__(self, other) -> bool:
|
| 285 |
+
if not isinstance(other, FineTuningSubset):
|
| 286 |
+
return NotImplemented
|
| 287 |
+
return self.metadata_file == other.metadata_file
|
| 288 |
+
|
| 289 |
+
|
| 290 |
class BaseDataset(torch.utils.data.Dataset):
|
| 291 |
+
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
| 292 |
super().__init__()
|
| 293 |
+
self.tokenizer = tokenizer
|
| 294 |
self.max_token_length = max_token_length
|
|
|
|
|
|
|
| 295 |
# width/height is used when enable_bucket==False
|
| 296 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
|
|
|
| 297 |
self.debug_dataset = debug_dataset
|
| 298 |
+
|
| 299 |
+
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
| 300 |
+
|
| 301 |
self.token_padding_disabled = False
|
|
|
|
|
|
|
| 302 |
self.tag_frequency = {}
|
| 303 |
|
| 304 |
self.enable_bucket = False
|
|
|
|
| 312 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
| 313 |
|
| 314 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
# augmentation
|
| 317 |
+
self.aug_helper = AugHelper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
| 320 |
|
| 321 |
self.image_data: Dict[str, ImageInfo] = {}
|
| 322 |
+
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
| 323 |
|
| 324 |
self.replacements = {}
|
| 325 |
|
| 326 |
def set_current_epoch(self, epoch):
|
| 327 |
self.current_epoch = epoch
|
| 328 |
+
self.shuffle_buckets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
def set_tag_frequency(self, dir_name, captions):
|
| 331 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
| 332 |
self.tag_frequency[dir_name] = frequency_for_dir
|
| 333 |
for caption in captions:
|
| 334 |
for tag in caption.split(","):
|
| 335 |
+
tag = tag.strip()
|
| 336 |
+
if tag:
|
| 337 |
tag = tag.lower()
|
| 338 |
frequency = frequency_for_dir.get(tag, 0)
|
| 339 |
frequency_for_dir[tag] = frequency + 1
|
|
|
|
| 344 |
def add_replacement(self, str_from, str_to):
|
| 345 |
self.replacements[str_from] = str_to
|
| 346 |
|
| 347 |
+
def process_caption(self, subset: BaseSubset, caption):
|
| 348 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
| 349 |
+
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
| 350 |
+
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
| 351 |
|
| 352 |
if is_drop_out:
|
| 353 |
caption = ""
|
| 354 |
else:
|
| 355 |
+
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
| 356 |
def dropout_tags(tokens):
|
| 357 |
+
if subset.caption_tag_dropout_rate <= 0:
|
| 358 |
return tokens
|
| 359 |
l = []
|
| 360 |
for token in tokens:
|
| 361 |
+
if random.random() >= subset.caption_tag_dropout_rate:
|
| 362 |
l.append(token)
|
| 363 |
return l
|
| 364 |
|
| 365 |
+
fixed_tokens = []
|
| 366 |
+
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
| 367 |
+
if subset.keep_tokens > 0:
|
| 368 |
+
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
| 369 |
+
flex_tokens = flex_tokens[subset.keep_tokens:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
+
if subset.shuffle_caption:
|
| 372 |
+
random.shuffle(flex_tokens)
|
| 373 |
|
| 374 |
+
flex_tokens = dropout_tags(flex_tokens)
|
| 375 |
|
| 376 |
+
caption = ", ".join(fixed_tokens + flex_tokens)
|
|
|
|
| 377 |
|
| 378 |
# textual inversion対応
|
| 379 |
for str_from, str_to in self.replacements.items():
|
|
|
|
| 427 |
input_ids = torch.stack(iids_list) # 3,77
|
| 428 |
return input_ids
|
| 429 |
|
| 430 |
+
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
| 431 |
self.image_data[info.image_key] = info
|
| 432 |
+
self.image_to_subset[info.image_key] = subset
|
| 433 |
|
| 434 |
def make_buckets(self):
|
| 435 |
'''
|
|
|
|
| 528 |
img = np.array(image, np.uint8)
|
| 529 |
return img
|
| 530 |
|
| 531 |
+
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
| 532 |
image_height, image_width = image.shape[0:2]
|
| 533 |
|
| 534 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
|
| 538 |
image_height, image_width = image.shape[0:2]
|
| 539 |
if image_width > reso[0]:
|
| 540 |
trim_size = image_width - reso[0]
|
| 541 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
| 542 |
# print("w", trim_size, p)
|
| 543 |
image = image[:, p:p + reso[0]]
|
| 544 |
if image_height > reso[1]:
|
| 545 |
trim_size = image_height - reso[1]
|
| 546 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
| 547 |
# print("h", trim_size, p)
|
| 548 |
image = image[p:p + reso[1]]
|
| 549 |
|
| 550 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
| 551 |
return image
|
| 552 |
|
| 553 |
+
def is_latent_cacheable(self):
|
| 554 |
+
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
| 555 |
+
|
| 556 |
def cache_latents(self, vae):
|
| 557 |
# TODO ここを高速化したい
|
| 558 |
print("caching latents.")
|
| 559 |
for info in tqdm(self.image_data.values()):
|
| 560 |
+
subset = self.image_to_subset[info.image_key]
|
| 561 |
+
|
| 562 |
if info.latents_npz is not None:
|
| 563 |
info.latents = self.load_latents_from_npz(info, False)
|
| 564 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
|
| 568 |
continue
|
| 569 |
|
| 570 |
image = self.load_image(info.absolute_path)
|
| 571 |
+
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
| 572 |
|
| 573 |
img_tensor = self.image_transforms(image)
|
| 574 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
| 575 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
| 576 |
|
| 577 |
+
if subset.flip_aug:
|
| 578 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
| 579 |
img_tensor = self.image_transforms(image)
|
| 580 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
|
| 584 |
image = Image.open(image_path)
|
| 585 |
return image.size
|
| 586 |
|
| 587 |
+
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
| 588 |
img = self.load_image(image_path)
|
| 589 |
|
| 590 |
face_cx = face_cy = face_w = face_h = 0
|
| 591 |
+
if subset.face_crop_aug_range is not None:
|
| 592 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
| 593 |
if len(tokens) >= 5:
|
| 594 |
face_cx = int(tokens[-4])
|
|
|
|
| 599 |
return img, face_cx, face_cy, face_w, face_h
|
| 600 |
|
| 601 |
# いい感じに切り出す
|
| 602 |
+
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
|
| 603 |
height, width = image.shape[0:2]
|
| 604 |
if height == self.height and width == self.width:
|
| 605 |
return image
|
|
|
|
| 607 |
# 画像サイズはsizeより大きいのでリサイズする
|
| 608 |
face_size = max(face_w, face_h)
|
| 609 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
| 610 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
| 611 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
| 612 |
if min_scale >= max_scale: # range指定がmin==max
|
| 613 |
scale = min_scale
|
| 614 |
else:
|
|
|
|
| 626 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
| 627 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
| 628 |
|
| 629 |
+
if subset.random_crop:
|
| 630 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
| 631 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
| 632 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
| 633 |
else:
|
| 634 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
| 635 |
+
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
| 636 |
if face_size > self.size // 10 and face_size >= 40:
|
| 637 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
| 638 |
|
|
|
|
| 655 |
return self._length
|
| 656 |
|
| 657 |
def __getitem__(self, index):
|
|
|
|
|
|
|
|
|
|
| 658 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
| 659 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
| 660 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
|
| 667 |
|
| 668 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
| 669 |
image_info = self.image_data[image_key]
|
| 670 |
+
subset = self.image_to_subset[image_key]
|
| 671 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
| 672 |
|
| 673 |
# image/latentsを処理する
|
| 674 |
if image_info.latents is not None:
|
| 675 |
+
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
|
| 676 |
image = None
|
| 677 |
elif image_info.latents_npz is not None:
|
| 678 |
+
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
|
| 679 |
latents = torch.FloatTensor(latents)
|
| 680 |
image = None
|
| 681 |
else:
|
| 682 |
# 画像を読み込み、必要ならcropする
|
| 683 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
| 684 |
im_h, im_w = img.shape[0:2]
|
| 685 |
|
| 686 |
if self.enable_bucket:
|
| 687 |
+
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
| 688 |
else:
|
| 689 |
if face_cx > 0: # 顔位置情報あり
|
| 690 |
+
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
| 691 |
elif im_h > self.height or im_w > self.width:
|
| 692 |
+
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
| 693 |
if im_h > self.height:
|
| 694 |
p = random.randint(0, im_h - self.height)
|
| 695 |
img = img[p:p + self.height]
|
|
|
|
| 701 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
| 702 |
|
| 703 |
# augmentation
|
| 704 |
+
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
| 705 |
+
if aug is not None:
|
| 706 |
+
img = aug(image=img)['image']
|
| 707 |
|
| 708 |
latents = None
|
| 709 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
|
| 711 |
images.append(image)
|
| 712 |
latents_list.append(latents)
|
| 713 |
|
| 714 |
+
caption = self.process_caption(subset, image_info.caption)
|
| 715 |
captions.append(caption)
|
| 716 |
if not self.token_padding_disabled: # this option might be omitted in future
|
| 717 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
|
| 742 |
|
| 743 |
|
| 744 |
class DreamBoothDataset(BaseDataset):
|
| 745 |
+
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
|
| 746 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
|
|
|
| 747 |
|
| 748 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
| 749 |
|
|
|
|
| 766 |
self.bucket_reso_steps = None # この情報は使われない
|
| 767 |
self.bucket_no_upscale = False
|
| 768 |
|
| 769 |
+
def read_caption(img_path, caption_extension):
|
| 770 |
# captionの候補ファイル名を作る
|
| 771 |
base_name = os.path.splitext(img_path)[0]
|
| 772 |
base_name_face_det = base_name
|
|
|
|
| 789 |
break
|
| 790 |
return caption
|
| 791 |
|
| 792 |
+
def load_dreambooth_dir(subset: DreamBoothSubset):
|
| 793 |
+
if not os.path.isdir(subset.image_dir):
|
| 794 |
+
print(f"not directory: {subset.image_dir}")
|
| 795 |
+
return [], []
|
| 796 |
|
| 797 |
+
img_paths = glob_images(subset.image_dir, "*")
|
| 798 |
+
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
|
| 800 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
| 801 |
captions = []
|
| 802 |
for img_path in img_paths:
|
| 803 |
+
cap_for_img = read_caption(img_path, subset.caption_extension)
|
| 804 |
+
if cap_for_img is None and subset.class_tokens is None:
|
| 805 |
+
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
| 806 |
+
captions.append("")
|
| 807 |
+
else:
|
| 808 |
+
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
| 809 |
|
| 810 |
+
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
| 811 |
|
| 812 |
+
return img_paths, captions
|
| 813 |
|
| 814 |
+
print("prepare images.")
|
|
|
|
| 815 |
num_train_images = 0
|
| 816 |
+
num_reg_images = 0
|
| 817 |
+
reg_infos: List[ImageInfo] = []
|
| 818 |
+
for subset in subsets:
|
| 819 |
+
if subset.num_repeats < 1:
|
| 820 |
+
print(
|
| 821 |
+
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
| 822 |
+
continue
|
| 823 |
+
|
| 824 |
+
if subset in self.subsets:
|
| 825 |
+
print(
|
| 826 |
+
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
| 827 |
+
continue
|
| 828 |
+
|
| 829 |
+
img_paths, captions = load_dreambooth_dir(subset)
|
| 830 |
+
if len(img_paths) < 1:
|
| 831 |
+
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
| 832 |
+
continue
|
| 833 |
+
|
| 834 |
+
if subset.is_reg:
|
| 835 |
+
num_reg_images += subset.num_repeats * len(img_paths)
|
| 836 |
+
else:
|
| 837 |
+
num_train_images += subset.num_repeats * len(img_paths)
|
| 838 |
|
| 839 |
for img_path, caption in zip(img_paths, captions):
|
| 840 |
+
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
| 841 |
+
if subset.is_reg:
|
| 842 |
+
reg_infos.append(info)
|
| 843 |
+
else:
|
| 844 |
+
self.register_image(info, subset)
|
| 845 |
|
| 846 |
+
subset.img_count = len(img_paths)
|
| 847 |
+
self.subsets.append(subset)
|
| 848 |
|
| 849 |
print(f"{num_train_images} train images with repeating.")
|
| 850 |
self.num_train_images = num_train_images
|
| 851 |
|
| 852 |
+
print(f"{num_reg_images} reg images.")
|
| 853 |
+
if num_train_images < num_reg_images:
|
| 854 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
|
|
|
|
|
|
| 855 |
|
| 856 |
+
if num_reg_images == 0:
|
| 857 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
| 858 |
+
else:
|
| 859 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
| 860 |
+
n = 0
|
| 861 |
+
first_loop = True
|
| 862 |
+
while n < num_train_images:
|
| 863 |
+
for info in reg_infos:
|
| 864 |
+
if first_loop:
|
| 865 |
+
self.register_image(info, subset)
|
| 866 |
+
n += info.num_repeats
|
| 867 |
+
else:
|
| 868 |
+
info.num_repeats += 1
|
| 869 |
+
n += 1
|
| 870 |
+
if n >= num_train_images:
|
| 871 |
+
break
|
| 872 |
+
first_loop = False
|
| 873 |
|
| 874 |
+
self.num_reg_images = num_reg_images
|
|
|
|
|
|
|
| 875 |
|
|
|
|
| 876 |
|
| 877 |
class FineTuningDataset(BaseDataset):
    def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        self.batch_size = batch_size

        self.num_train_images = 0
        self.num_reg_images = 0

        for subset in subsets:
            if subset.num_repeats < 1:
                print(
                    f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
                continue

            if subset in self.subsets:
                print(
                    f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
                continue

            # load the metadata
            if os.path.exists(subset.metadata_file):
                print(f"loading existing metadata: {subset.metadata_file}")
                with open(subset.metadata_file, "rt", encoding='utf-8') as f:
                    metadata = json.load(f)
            else:
                raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")

            if len(metadata) < 1:
                print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
                continue

            tags_list = []
            for image_key, img_md in metadata.items():
                # build the path information
                if os.path.exists(image_key):
                    abs_path = image_key
                else:
                    npz_path = os.path.join(subset.image_dir, image_key + ".npz")
                    if os.path.exists(npz_path):
                        abs_path = npz_path
                    else:
                        # fairly rough, but no better way comes to mind
                        abs_path = glob_images(subset.image_dir, image_key)
                        assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
                        abs_path = abs_path[0]

                caption = img_md.get('caption')
                tags = img_md.get('tags')
                if caption is None:
                    caption = tags
                elif tags is not None and len(tags) > 0:
                    caption = caption + ', ' + tags
                    tags_list.append(tags)

                if caption is None:
                    caption = ""

                image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
                image_info.image_size = img_md.get('train_resolution')

                if not subset.color_aug and not subset.random_crop:
                    # if npz exists, use them
                    image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)

                self.register_image(image_info, subset)

            self.num_train_images += len(metadata) * subset.num_repeats

            # TODO do not record tag freq when no tag
            self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
            subset.img_count = len(metadata)
            self.subsets.append(subset)

        # check existence of all npz files
        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
        if use_npz_latents:
            flip_aug_in_subset = False
            npz_any = False
            npz_all = True

            for image_info in self.image_data.values():
                subset = self.image_to_subset[image_info.image_key]

                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz

                if subset.flip_aug:
                    has_npz = has_npz and image_info.latents_npz_flipped is not None
                    flip_aug_in_subset = True
                npz_all = npz_all and has_npz

            if npz_any and not npz_all:
                ...
            elif not npz_all:
                use_npz_latents = False
                print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
                if flip_aug_in_subset:
                    print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
        # else:
        #   print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")

        ...
            for image_info in self.image_data.values():
                image_info.latents_npz = image_info.latents_npz_flipped = None

    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
        base_name = os.path.splitext(image_key)[0]
        npz_file_norm = base_name + '.npz'
        ...
            return npz_file_norm, npz_file_flip

        # image_key is relative path
        npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
        npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')

        if not os.path.exists(npz_file_norm):
            npz_file_norm = None
        ...
        return npz_file_norm, npz_file_flip

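The keys read above ('caption', 'tags', 'train_resolution') imply the general shape of the metadata JSON that this dataset expects; the entry below is purely illustrative (file name and values are made up, it is not taken from the repository):

    {
      "img_0001": {
        "caption": "a photo of a cat sitting on a sofa",
        "tags": "cat, sofa, indoors",
        "train_resolution": [512, 512]
      }
    }
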
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
    def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
        self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]

        super().__init__(datasets)

        self.image_data = {}
        self.num_train_images = 0
        self.num_reg_images = 0

        # simply concat together
        # TODO: handle image_data key duplication among datasets
        # In practice this is not a big issue, because image_data is accessed from outside the dataset only for debug_dataset.
        for dataset in datasets:
            self.image_data.update(dataset.image_data)
            self.num_train_images += dataset.num_train_images
            self.num_reg_images += dataset.num_reg_images

    def add_replacement(self, str_from, str_to):
        for dataset in self.datasets:
            dataset.add_replacement(str_from, str_to)

    # def make_buckets(self):
    #   for dataset in self.datasets:
    #     dataset.make_buckets()

    def cache_latents(self, vae):
        for i, dataset in enumerate(self.datasets):
            print(f"[Dataset {i}]")
            dataset.cache_latents(vae)

    def is_latent_cacheable(self) -> bool:
        return all([dataset.is_latent_cacheable() for dataset in self.datasets])

    def set_current_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def disable_token_padding(self):
        for dataset in self.datasets:
            dataset.disable_token_padding()

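The training scripts are expected to hand one or more datasets to this wrapper so that the rest of the code can treat them as a single dataset; a minimal usage sketch (variable names are illustrative, not an excerpt from the scripts):

    # sketch: combine datasets and query the aggregated counts
    train_dataset_group = DatasetGroup([dreambooth_dataset, fine_tuning_dataset])
    print(train_dataset_group.num_train_images, train_dataset_group.num_reg_images)
    if cache_latents:
        train_dataset_group.cache_latents(vae)
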
def debug_dataset(train_dataset, show_input_ids=False):
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("Escape for exit. / Escキーで中断、終了します")

    train_dataset.set_current_epoch(1)
    k = 0
    indices = list(range(len(train_dataset)))
    random.shuffle(indices)
    for i, idx in enumerate(indices):
        example = train_dataset[idx]
        if example['latents'] is not None:
            print(f"sample has latents from npz file: {example['latents'].size()}")
        for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
            ...

...

                        help='enable v-parameterization training / v-parameterization学習を有効にする')
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
    parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
                        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")

def add_optimizer_arguments(parser: argparse.ArgumentParser):
    parser.add_argument("--optimizer_type", type=str, default="",
                        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")

    # backward compatibility
    parser.add_argument("--use_8bit_adam", action="store_true",
                        help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
    parser.add_argument("--use_lion_optimizer", action="store_true",
                        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")

    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")

    parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
                        help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")

    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
    parser.add_argument("--lr_warmup_steps", type=int, default=0,
                        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
    parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
    parser.add_argument("--lr_scheduler_power", type=float, default=1,
                        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    ...
    parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
    parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                        help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
    parser.add_argument("--mem_eff_attn", action="store_true",
                        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
    parser.add_argument("--xformers", action="store_true",
    ...
    parser.add_argument("--vae", type=str, default=None,
                        help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")

    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    parser.add_argument("--max_train_epochs", type=int, default=None,
                        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
    ...
    parser.add_argument("--logging_dir", type=str, default=None,
                        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
    parser.add_argument("--noise_offset", type=float, default=None,
                        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
    parser.add_argument("--lowram", action="store_true",
                        help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")

    parser.add_argument("--sample_every_n_steps", type=int, default=None,
                        help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
    parser.add_argument("--sample_every_n_epochs", type=int, default=None,
                        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
    parser.add_argument("--sample_prompts", type=str, default=None,
                        help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
    parser.add_argument('--sample_sampler', type=str, default='ddim',
                        choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
                                 'dpmsolver++', 'dpmsingle',
                                 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
                        help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')

    if support_dreambooth:
        # DreamBooth training
        parser.add_argument("--prior_loss_weight", type=float, default=1.0,
        ...

    ...
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
    parser.add_argument("--keep_tokens", type=int, default=0,
                        help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
    parser.add_argument("--face_crop_aug_range", type=str, default=None,
    ...

    if support_caption_dropout:
        # Textual Inversion does not support caption dropout
        # prefixed with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults match the other options
        parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
                            help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
        parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
                            help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
        parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
                            help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")

    if support_dreambooth:
        ...

# region utils


def get_optimizer(args, trainable_params):
    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"

    optimizer_type = args.optimizer_type
    if args.use_8bit_adam:
        assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
        assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
        optimizer_type = "AdamW8bit"

    elif args.use_lion_optimizer:
        assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
        optimizer_type = "Lion"

    if optimizer_type is None or optimizer_type == "":
        optimizer_type = "AdamW"
    optimizer_type = optimizer_type.lower()

    # parse the extra arguments: only bool, float and tuples of float are supported
    optimizer_kwargs = {}
    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
        for arg in args.optimizer_args:
            key, value = arg.split('=')

            value = value.split(",")
            for i in range(len(value)):
                if value[i].lower() == "true" or value[i].lower() == "false":
                    value[i] = (value[i].lower() == "true")
                else:
                    value[i] = float(value[i])
            if len(value) == 1:
                value = value[0]
            else:
                value = tuple(value)

            optimizer_kwargs[key] = value
    # print("optkwargs:", optimizer_kwargs)

    lr = args.learning_rate

    if optimizer_type == "AdamW8bit".lower():
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = bnb.optim.AdamW8bit
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "SGDNesterov8bit".lower():
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
        if "momentum" not in optimizer_kwargs:
            print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
            optimizer_kwargs["momentum"] = 0.9

        optimizer_class = bnb.optim.SGD8bit
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type == "Lion".lower():
        try:
            import lion_pytorch
        except ImportError:
            raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
        print(f"use Lion optimizer | {optimizer_kwargs}")
        optimizer_class = lion_pytorch.Lion
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "SGDNesterov".lower():
        print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
        if "momentum" not in optimizer_kwargs:
            print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
            optimizer_kwargs["momentum"] = 0.9

        optimizer_class = torch.optim.SGD
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type == "DAdaptation".lower():
        try:
            import dadaptation
        except ImportError:
            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
        print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")

        actual_lr = lr
        lr_count = 1
        if type(trainable_params) == list and type(trainable_params[0]) == dict:
            lrs = set()
            actual_lr = trainable_params[0].get("lr", actual_lr)
            for group in trainable_params:
                lrs.add(group.get("lr", actual_lr))
            lr_count = len(lrs)

        if actual_lr <= 0.1:
            print(
                f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
            print('recommend option: lr=1.0 / 推奨は1.0です')
        if lr_count > 1:
            print(
                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")

        optimizer_class = dadaptation.DAdaptAdam
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "Adafactor".lower():
        # check the arguments and correct them as needed
        if "relative_step" not in optimizer_kwargs:
            optimizer_kwargs["relative_step"] = True  # default
        if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
            print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
            optimizer_kwargs["relative_step"] = True
        print(f"use Adafactor optimizer | {optimizer_kwargs}")

        if optimizer_kwargs["relative_step"]:
            print(f"relative_step is true / relative_stepがtrueです")
            if lr != 0.0:
                print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
            args.learning_rate = None

            # when trainable_params is a list of param groups, remove the per-group lr
            if type(trainable_params) == list and type(trainable_params[0]) == dict:
                has_group_lr = False
                for group in trainable_params:
                    p = group.pop("lr", None)
                    has_group_lr = has_group_lr or (p is not None)

                if has_group_lr:
                    # also invalidate the args; TODO this reverses the dependency, which is not ideal
                    print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
                    args.unet_lr = None
                    args.text_encoder_lr = None

            if args.lr_scheduler != "adafactor":
                print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
                args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky

            lr = None
        else:
            if args.max_grad_norm != 0.0:
                print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
            if args.lr_scheduler != "constant_with_warmup":
                print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
            if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
                print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")

        optimizer_class = transformers.optimization.Adafactor
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "AdamW".lower():
        print(f"use AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    else:
        # use an arbitrary optimizer class
        optimizer_type = args.optimizer_type  # the original, non-lowercased value (a bit awkward)
        print(f"use {optimizer_type} | {optimizer_kwargs}")
        if "." not in optimizer_type:
            optimizer_module = torch.optim
        else:
            values = optimizer_type.split(".")
            optimizer_module = importlib.import_module(".".join(values[:-1]))
            optimizer_type = values[-1]

        optimizer_class = getattr(optimizer_module, optimizer_type)
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
    optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

    return optimizer_name, optimizer_args, optimizer

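To make the parsing rule above concrete (only booleans and floats are handled, and commas build tuples), here is a small standalone sketch of the same logic together with the result it produces; it is an illustration, not a function exported by the library:

    def parse_optimizer_args(arg_strings):
        # same rule as above: "key=value" pairs, bools and floats only, commas make tuples
        kwargs = {}
        for arg in arg_strings:
            key, value = arg.split('=')
            parts = [v.lower() == 'true' if v.lower() in ('true', 'false') else float(v)
                     for v in value.split(',')]
            kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
        return kwargs

    # parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999", "amsgrad=True"])
    # -> {'weight_decay': 0.01, 'betas': (0.9, 0.999), 'amsgrad': True}
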
# Monkeypatch a newer get_scheduler() function, overriding the current version of diffusers.optimizer.get_scheduler.
# The code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6,
# which is a newer release of diffusers than the one currently packaged with sd-scripts.
# This code can be removed once a newer diffusers version (v0.12.1 or greater) is tested and adopted in sd-scripts.


def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.
    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    if name.startswith("adafactor"):
        assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
        initial_lr = float(name.split(':')[1])
        # print("adafactor scheduler init lr", initial_lr)
        return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)

    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

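A typical call site builds the scheduler from the CLI options declared earlier; the sketch below assumes the usual variable names from the training scripts (args, optimizer) and is not a verbatim excerpt:

    lr_scheduler = get_scheduler_fix(
        args.lr_scheduler, optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_scheduler_num_cycles,
        power=args.lr_scheduler_power)
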
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
    # backward compatibility
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention
        args.caption_extention = None

    # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
    if args.resolution is not None:
        args.resolution = tuple([int(r) for r in args.resolution.split(',')])
    ...


def load_tokenizer(args: argparse.Namespace):
    print("prepare tokenizer")
    original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH

    tokenizer: CLIPTokenizer = None
    if args.tokenizer_cache_dir:
        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
        if os.path.exists(local_tokenizer_path):
            print(f"load tokenizer from cache: {local_tokenizer_path}")
            tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)  # same for v1 and v2

    if tokenizer is None:
        if args.v2:
            tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
        else:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

    if hasattr(args, "max_token_length") and args.max_token_length is not None:
        print(f"update token length: {args.max_token_length}")

    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
        print(f"save Tokenizer to cache: {local_tokenizer_path}")
        tokenizer.save_pretrained(local_tokenizer_path)

    return tokenizer

...

def load_target_model(args: argparse.Namespace, weight_dtype):
    name_or_path = args.pretrained_model_name_or_path
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
    if load_stable_diffusion_format:
        print("load StableDiffusion checkpoint")
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
    else:
        print("load Diffusers pretrained models")
        try:
            pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
        except EnvironmentError as ex:
            print(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet
    ...

...
    model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
    accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))

# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = 'scaled_linear'


def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
    """
    The Diffusers pipeline used for generation is the default one, so prompt weighting is not supported.
    clip skip is supported.
    """
    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
        return
    if args.sample_every_n_epochs is not None:
        # sample_every_n_steps is ignored
        if epoch is None or epoch % args.sample_every_n_epochs != 0:
            return
    else:
        if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
            return

    print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts):
        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    org_vae_device = vae.device  # it should be on the CPU
    vae.to(device)

    # build a wrapper for clip skip support
    if args.clip_skip is None:
        text_encoder_or_wrapper = text_encoder
    else:
        class Wrapper():
            def __init__(self, tenc) -> None:
                self.tenc = tenc
                self.config = {}
                super().__init__()

            def __call__(self, input_ids, attention_mask):
                enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
                encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
                encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
                pooled_output = enc_out['pooler_output']
                return encoder_hidden_states, pooled_output  # 1st output is only used

        text_encoder_or_wrapper = Wrapper(text_encoder)

    # read prompts
    with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
        prompts = f.readlines()

    # prepare the scheduler
    sched_init_args = {}
    if args.sample_sampler == "ddim":
        scheduler_cls = DDIMScheduler
    elif args.sample_sampler == "ddpm":  # ddpm gives broken results, so it is left out of the CLI options
        scheduler_cls = DDPMScheduler
    elif args.sample_sampler == "pndm":
        scheduler_cls = PNDMScheduler
    elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
        scheduler_cls = LMSDiscreteScheduler
    elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
        scheduler_cls = EulerDiscreteScheduler
    elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
        scheduler_cls = EulerAncestralDiscreteScheduler
    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args['algorithm_type'] = args.sample_sampler
    elif args.sample_sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif args.sample_sampler == "heun":
        scheduler_cls = HeunDiscreteScheduler
    elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
        scheduler_cls = KDPM2DiscreteScheduler
    elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
        scheduler_cls = KDPM2AncestralDiscreteScheduler
    else:
        scheduler_cls = DDIMScheduler

    if args.v_parameterization:
        sched_init_args['prediction_type'] = 'v_prediction'

    scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
                              beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
                              beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)

    # force clip_sample=True
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        # print("set clip_sample to True")
        scheduler.config.clip_sample = True

    pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
                                       scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
    pipeline.to(device)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state()

    with torch.no_grad():
        with accelerator.autocast():
            for i, prompt in enumerate(prompts):
                if not accelerator.is_main_process:
                    continue
                prompt = prompt.strip()
                if len(prompt) == 0 or prompt[0] == '#':
                    continue

                # subset of gen_img_diffusers
                prompt_args = prompt.split(' --')
                prompt = prompt_args[0]
                negative_prompt = None
                sample_steps = 30
                width = height = 512
                scale = 7.5
                seed = None
                for parg in prompt_args:
                    try:
                        m = re.match(r'w (\d+)', parg, re.IGNORECASE)
                        if m:
                            width = int(m.group(1))
                            continue

                        m = re.match(r'h (\d+)', parg, re.IGNORECASE)
                        if m:
                            height = int(m.group(1))
                            continue

                        m = re.match(r'd (\d+)', parg, re.IGNORECASE)
                        if m:
                            seed = int(m.group(1))
                            continue

                        m = re.match(r's (\d+)', parg, re.IGNORECASE)
                        if m:  # steps
                            sample_steps = max(1, min(1000, int(m.group(1))))
                            continue

                        m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
                        if m:  # scale
                            scale = float(m.group(1))
                            continue

                        m = re.match(r'n (.+)', parg, re.IGNORECASE)
                        if m:  # negative prompt
                            negative_prompt = m.group(1)
                            continue

                    except ValueError as ex:
                        print(f"Exception in parsing / 解析エラー: {parg}")
                        print(ex)

                if seed is not None:
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)

                if prompt_replacement is not None:
                    prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                    if negative_prompt is not None:
                        negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

                height = max(64, height - height % 8)  # round to divisible by 8
                width = max(64, width - width % 8)  # round to divisible by 8
                print(f"prompt: {prompt}")
                print(f"negative_prompt: {negative_prompt}")
                print(f"height: {height}")
                print(f"width: {width}")
                print(f"sample_steps: {sample_steps}")
                print(f"scale: {scale}")
                image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]

                ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
                num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
                seed_suffix = "" if seed is None else f"_{seed}"
                img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"

                image.save(os.path.join(save_dir, img_filename))

    # clear pipeline and cache to reduce vram usage
    del pipeline
    torch.cuda.empty_cache()

    torch.set_rng_state(rng_state)
    torch.cuda.set_rng_state(cuda_rng_state)
    vae.to(org_vae_device)

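Putting the option parsing above together, each non-comment line of the --sample_prompts file is a prompt followed by optional " --x value" switches (w/h: size, d: seed, s: steps, l: CFG scale, n: negative prompt). An illustrative prompt file (the prompts themselves are made up):

    # lines starting with # are ignored
    1girl, solo, looking at viewer --n lowres, bad anatomy --w 512 --h 768 --d 1234 --s 28 --l 7.5
    a painting of a mountain lake at sunrise --w 640 --h 448 --s 30
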
# endregion

# region preprocessing

networks/check_lora_weights.py
CHANGED

@@ -21,7 +21,7 @@ def main(file):
 
     for key, value in values:
         value = value.to(torch.float32)
-        print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
+        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
 
 
 if __name__ == '__main__':

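With the added size field, each output line now reads key, shape, mean(|w|), min(|w|); an illustrative line (the key and numbers are invented for the example, the shape tuple is printed with ", " replaced by "-"):

    lora_unet_mid_block_attentions_0_proj_in.lora_down.weight,(4-320-1-1),0.0123,1.5e-06
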
networks/extract_lora_from_models.py
CHANGED

@@ -45,8 +45,13 @@ def svd(args):
     text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
 
     # create LoRA network to extract weights: Use dim (rank) as alpha
+    if args.conv_dim is None:
+        kwargs = {}
+    else:
+        kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
+
+    lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
+    lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
     assert len(lora_network_o.text_encoder_loras) == len(
         lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
 
@@ -85,13 +90,28 @@ def svd(args):
 
     # make LoRA with svd
     print("calculating by svd")
-    rank = args.dim
     lora_weights = {}
     with torch.no_grad():
         for lora_name, mat in tqdm(list(diffs.items())):
+            # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
             conv2d = (len(mat.size()) == 4)
+            kernel_size = None if not conv2d else mat.size()[2:4]
+            conv2d_3x3 = conv2d and kernel_size != (1, 1)
+
+            rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
+            out_dim, in_dim = mat.size()[0:2]
+
+            if args.device:
+                mat = mat.to(args.device)
+
+            # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
+            rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
+
             if conv2d:
+                if conv2d_3x3:
+                    mat = mat.flatten(start_dim=1)
+                else:
+                    mat = mat.squeeze()
 
             U, S, Vh = torch.linalg.svd(mat)
 
@@ -108,30 +128,27 @@ def svd(args):
             U = U.clamp(low_val, hi_val)
             Vh = Vh.clamp(low_val, hi_val)
 
-    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)  # to make state dict
-    lora_sd = lora_network_o.state_dict()
-    print(f"LoRA has {len(lora_sd)} weights.")
-
-    for key in list(lora_sd.keys()):
-        if "alpha" in key:
-            continue
-
-        # print(key, i, weights.size(), lora_sd[key].size())
-        if len(lora_sd[key].size()) == 4:
-            weights = weights.unsqueeze(2).unsqueeze(3)
-
+            if conv2d:
+                U = U.reshape(out_dim, rank, 1, 1)
+                Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
+
+            U = U.to("cpu").contiguous()
+            Vh = Vh.to("cpu").contiguous()
+
+            lora_weights[lora_name] = (U, Vh)
+
+    # make state dict for LoRA
+    lora_sd = {}
+    for lora_name, (up_weight, down_weight) in lora_weights.items():
+        lora_sd[lora_name + '.lora_up.weight'] = up_weight
+        lora_sd[lora_name + '.lora_down.weight'] = down_weight
+        lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
 
     # load state dict to LoRA and save it
+    lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
+    lora_network_save.apply_to(text_encoder_o, unet_o)  # create internal module references for state_dict
+
+    info = lora_network_save.load_state_dict(lora_sd)
     print(f"Loading extracted LoRA weights: {info}")
 
     dir_name = os.path.dirname(args.save_to)
@@ -139,9 +156,9 @@ def svd(args):
     os.makedirs(dir_name, exist_ok=True)
 
     # minimum metadata
-    metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
+    metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
 
+    lora_network_save.save_weights(args.save_to, save_dtype, metadata)
     print(f"LoRA weights are saved to: {args.save_to}")
 
 
@@ -158,6 +175,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_to", type=str, default=None,
                         help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
     parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
+    parser.add_argument("--conv_dim", type=int, default=None,
+                        help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
 
     args = parser.parse_args()

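The lines hidden between the SVD call and the clamping keep only the strongest `rank` singular directions before the two factors become lora_up and lora_down; the snippet below is a generic sketch of that low-rank truncation, stated as an assumption rather than the file's exact code:

    U, S, Vh = torch.linalg.svd(mat)
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)   # fold the singular values into the up factor
    Vh = Vh[:rank, :]
    # mat is approximated by U @ Vh; U feeds lora_up and Vh feeds lora_down
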
networks/lora.py
CHANGED

@@ -6,6 +6,7 @@
 import math
 import os
 from typing import List
+import numpy as np
 import torch
 
 from library import train_util
@@ -20,22 +21,34 @@ class LoRAModule(torch.nn.Module):
         """ if alpha == 0 or None, alpha is rank (no scaling). """
         super().__init__()
         self.lora_name = lora_name
-        self.lora_dim = lora_dim
 
         if org_module.__class__.__name__ == 'Conv2d':
             in_dim = org_module.in_channels
             out_dim = org_module.out_channels
-            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
-            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
         else:
             in_dim = org_module.in_features
             out_dim = org_module.out_features
+
+        # if limit_rank:
+        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
+        #   if self.lora_dim != lora_dim:
+        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        # else:
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == 'Conv2d':
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
 
         if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
-        alpha = lora_dim if alpha is None or alpha == 0 else alpha
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
         self.scale = alpha / self.lora_dim
         self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
@@ -45,69 +58,192 @@ class LoRAModule(torch.nn.Module):
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
+        self.region = None
+        self.region_mask = None
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
         self.org_module.forward = self.forward
         del self.org_module
 
+    def set_region(self, region):
+        self.region = region
+        self.region_mask = None
+
     def forward(self, x):
+        if self.region is None:
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+        # regional LoRA  FIXME same as additional-network extension
+        if x.size()[1] % 77 == 0:
+            # print(f"LoRA for context: {self.lora_name}")
+            self.region = None
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+        # calculate region mask first time
+        if self.region_mask is None:
+            if len(x.size()) == 4:
+                h, w = x.size()[2:4]
+            else:
+                seq_len = x.size()[1]
+                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
+                h = int(self.region.size()[0] / ratio + .5)
+                w = seq_len // h
+
+            r = self.region.to(x.device)
+            if r.dtype == torch.bfloat16:
+                r = r.to(torch.float)
+            r = r.unsqueeze(0).unsqueeze(1)
+            # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
+            r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
+            r = r.to(x.dtype)
+
+            if len(x.size()) == 3:
+                r = torch.reshape(r, (1, x.size()[1], -1))
+
+            self.region_mask = r
+
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
 
 
 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
-    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
-    return network
 
+    # extract dim/alpha for conv2d, and block dim
+    conv_dim = kwargs.get('conv_dim', None)
+    conv_alpha = kwargs.get('conv_alpha', None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = 1.0
+        else:
+            conv_alpha = float(conv_alpha)
 
+    """
+    block_dims = kwargs.get("block_dims")
+    block_alphas = None
+
+    if block_dims is not None:
+        block_dims = [int(d) for d in block_dims.split(',')]
+        assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+        block_alphas = kwargs.get("block_alphas")
+        if block_alphas is None:
+            block_alphas = [1] * len(block_dims)
+        else:
+            block_alphas = [int(a) for a in block_alphas(',')]
+        assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+
+    conv_block_dims = kwargs.get("conv_block_dims")
+    conv_block_alphas = None
+
+    if conv_block_dims is not None:
+        conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
+        assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+        conv_block_alphas = kwargs.get("conv_block_alphas")
+        if conv_block_alphas is None:
+            conv_block_alphas = [1] * len(conv_block_dims)
+        else:
+            conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
+        assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+    """
 
+    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
+                          alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
+    return network
 
-    for key, value in weights_sd.items():
-        if network_alpha is None and 'alpha' in key:
-            network_alpha = value
-        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
-            network_dim = value.size()[0]
 
-    if network_alpha is None:
-        network_alpha = network_dim
-
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == '.safetensors':
+            from safetensors.torch import load_file, safe_open
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location='cpu')
+
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if '.' not in key:
+            continue
+
+        lora_name = key.split('.')[0]
+        if 'alpha' in key:
+            modules_alpha[lora_name] = value
+        elif 'lora_down' in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha[key] = modules_dim[key]
+
+    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
     network.weights_sd = weights_sd
     return network
 
 
 class LoRANetwork(torch.nn.Module):
+    # is it possible to apply conv_in and conv_out?
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
+    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
         super().__init__()
         self.multiplier = multiplier
+
         self.lora_dim = lora_dim
         self.alpha = alpha
+        self.conv_lora_dim = conv_lora_dim
+        self.conv_alpha = conv_alpha
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+
+        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
+        if self.apply_to_conv2d_3x3:
+            if self.conv_alpha is None:
+                self.conv_alpha = self.alpha
+            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
         def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
             loras = []
             for name, module in root_module.named_modules():
                 if module.__class__.__name__ in target_replace_modules:
+                    # TODO get block index here
                     for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+                        if is_linear or is_conv2d:
                             lora_name = prefix + '.' + name + '.' + child_name
                             lora_name = lora_name.replace('.', '_')
+
+                            if modules_dim is not None:
...
                             loras.append(lora)
             return loras
@@ -115,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
             text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
         self.weights_sd = None
@@ -126,6 +267,11 @@ class LoRANetwork(torch.nn.Module):
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
     def load_weights(self, file):
         if os.path.splitext(file)[1] == '.safetensors':
             from safetensors.torch import load_file, safe_open
@@ -235,3 +381,18 @@ class LoRANetwork(torch.nn.Module):
             save_file(state_dict, file, metadata)
         else:
             torch.save(state_dict, file)
|
| 232 |
+
if lora_name not in modules_dim:
|
| 233 |
+
continue # no LoRA module in this weights file
|
| 234 |
+
dim = modules_dim[lora_name]
|
| 235 |
+
alpha = modules_alpha[lora_name]
|
| 236 |
+
else:
|
| 237 |
+
if is_linear or is_conv2d_1x1:
|
| 238 |
+
dim = self.lora_dim
|
| 239 |
+
alpha = self.alpha
|
| 240 |
+
elif self.apply_to_conv2d_3x3:
|
| 241 |
+
dim = self.conv_lora_dim
|
| 242 |
+
alpha = self.conv_alpha
|
| 243 |
+
else:
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
| 247 |
loras.append(lora)
|
| 248 |
return loras
|
| 249 |
|
|
|
|
| 251 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 252 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 253 |
|
| 254 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 255 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 256 |
+
if modules_dim is not None or self.conv_lora_dim is not None:
|
| 257 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 258 |
+
|
| 259 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
| 260 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 261 |
|
| 262 |
self.weights_sd = None
|
|
|
|
| 267 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
| 268 |
names.add(lora.lora_name)
|
| 269 |
|
| 270 |
+
def set_multiplier(self, multiplier):
|
| 271 |
+
self.multiplier = multiplier
|
| 272 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 273 |
+
lora.multiplier = self.multiplier
|
| 274 |
+
|
| 275 |
def load_weights(self, file):
|
| 276 |
if os.path.splitext(file)[1] == '.safetensors':
|
| 277 |
from safetensors.torch import load_file, safe_open
|
|
|
|
| 381 |
save_file(state_dict, file, metadata)
|
| 382 |
else:
|
| 383 |
torch.save(state_dict, file)
|
| 384 |
+
|
| 385 |
+
@ staticmethod
|
| 386 |
+
def set_regions(networks, image):
|
| 387 |
+
image = image.astype(np.float32) / 255.0
|
| 388 |
+
for i, network in enumerate(networks[:3]):
|
| 389 |
+
# NOTE: consider averaging overwrapping area
|
| 390 |
+
region = image[:, :, i]
|
| 391 |
+
if region.max() == 0:
|
| 392 |
+
continue
|
| 393 |
+
region = torch.tensor(region)
|
| 394 |
+
network.set_region(region)
|
| 395 |
+
|
| 396 |
+
def set_region(self, region):
|
| 397 |
+
for lora in self.unet_loras:
|
| 398 |
+
lora.set_region(region)
|
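
The forward path above recovers the latent height and width from the attention sequence length before resampling the per-network region mask. A minimal standalone sketch of just that resizing step; the mask and activation shapes below are illustrative, and the LoRA projections themselves are omitted:

import math
import torch

# Illustrative shapes: a 512x768 mask and an attention input with 6144 tokens
# (a 64x96 latent flattened), mirroring the h/w inference in LoRAModule.forward.
region = torch.rand(512, 768)      # per-network mask, H x W, values in [0, 1]
x = torch.randn(2, 6144, 320)      # (batch, seq_len, channels)

seq_len = x.size()[1]
ratio = math.sqrt((region.size()[0] * region.size()[1]) / seq_len)
h = int(region.size()[0] / ratio + .5)
w = seq_len // h

r = region.unsqueeze(0).unsqueeze(1)                              # -> (1, 1, H, W)
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')   # -> (1, 1, h, w)
region_mask = torch.reshape(r, (1, x.size()[1], -1))              # -> (1, seq_len, 1)

print(h, w, region_mask.shape)  # 64 96 torch.Size([1, 6144, 1])
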
networks/merge_lora.py
CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
    for name, module in root_module.named_modules():
      if module.__class__.__name__ in target_replace_modules:
        for child_name, child_module in module.named_modules():
-          if child_module.__class__.__name__ == "Linear" or
+          if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
            lora_name = prefix + '.' + name + '.' + child_name
            lora_name = lora_name.replace('.', '_')
            name_to_module[lora_name] = child_module
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):

      # W <- W + U * D
      weight = module.weight
+      # print(module_name, down_weight.size(), up_weight.size())
      if len(weight.size()) == 2:
        # linear
        weight = weight + ratio * (up_weight @ down_weight) * scale
-
-        # conv2d
+      elif down_weight.size()[2:4] == (1, 1):
+        # conv2d 1x1
        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                   ).unsqueeze(2).unsqueeze(3) * scale
+      else:
+        # conv2d 3x3
+        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+        # print(conved.size(), weight.size(), module.stride, module.padding)
+        weight = weight + ratio * conved * scale

      module.weight = torch.nn.Parameter(weight)

@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
      alphas[lora_module_name] = alpha
      if lora_module_name not in base_alphas:
        base_alphas[lora_module_name] = alpha
-
+
  print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

  # merge
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
      merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
    else:
      merged_sd[key] = lora_sd[key] * scale
-
+
  # set alpha to sd
  for lora_module_name, alpha in base_alphas.items():
    key = lora_module_name + ".alpha"
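
The else branch added above folds a 3x3 lora_down and a 1x1 lora_up into a single 3x3 delta kernel with torch.nn.functional.conv2d on the permuted down weight. A small self-contained check of that equivalence; all sizes here are hypothetical:

import torch
import torch.nn.functional as F

# Hypothetical rank-4 LoRA on a Conv2d(8 -> 16) with a 3x3 kernel.
rank, in_ch, out_ch = 4, 8, 16
down_weight = torch.randn(rank, in_ch, 3, 3)   # lora_down: 3x3 conv, in_ch -> rank
up_weight = torch.randn(out_ch, rank, 1, 1)    # lora_up:   1x1 conv, rank -> out_ch

# Single merged 3x3 kernel, as composed in merge_to_sd_model above.
conved = F.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

# Sanity check: applying the merged kernel equals applying down then up.
x = torch.randn(1, in_ch, 16, 16)
y_two_step = F.conv2d(F.conv2d(x, down_weight, padding=1), up_weight)
y_merged = F.conv2d(x, conved, padding=1)
print(conved.shape)                                      # torch.Size([16, 8, 3, 3])
print(torch.allclose(y_two_step, y_merged, atol=1e-4))   # True up to float error
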
networks/resize_lora.py
CHANGED
@@ -1,14 +1,15 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
-# Thanks to cloneofsimo
+# Thanks to cloneofsimo

import argparse
-import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
+import numpy as np

+MIN_SV = 1e-6

def load_state_dict(file_name, dtype):
  if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
    torch.save(model, file_name)


+def index_sv_cumulative(S, target):
+  original_sum = float(torch.sum(S))
+  cumulative_sums = torch.cumsum(S, dim=0)/original_sum
+  index = int(torch.searchsorted(cumulative_sums, target)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+def index_sv_fro(S, target):
+  S_squared = S.pow(2)
+  s_fro_sq = float(torch.sum(S_squared))
+  sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
+  index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+# Modified from Kohaku-blueleaf's extract/merge functions
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size, kernel_size, _ = weight.size()
+  U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size = weight.size()
+
+  U, S, Vh = torch.linalg.svd(weight.to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+  in_rank, in_size, kernel_size, k_ = lora_down.shape
+  out_size, out_rank, _, _ = lora_up.shape
+  assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+  weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+  del lora_up, lora_down
+  return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+  in_rank, in_size = lora_down.shape
+  out_size, out_rank = lora_up.shape
+  assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  weight = lora_up @ lora_down
+  del lora_up, lora_down
+  return weight
+
+
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+  param_dict = {}
+
+  if dynamic_method=="sv_ratio":
+    # Calculate new dim and alpha based off ratio
+    max_sv = S[0]
+    min_sv = max_sv/dynamic_param
+    new_rank = max(torch.sum(S > min_sv).item(), 1)
+    new_alpha = float(scale*new_rank)
+
+  elif dynamic_method=="sv_cumulative":
+    # Calculate new dim and alpha based off cumulative sum
+    new_rank = index_sv_cumulative(S, dynamic_param)
+    new_rank = max(new_rank, 1)
+    new_alpha = float(scale*new_rank)
+
+  elif dynamic_method=="sv_fro":
+    # Calculate new dim and alpha based off sqrt sum of squares
+    new_rank = index_sv_fro(S, dynamic_param)
+    new_rank = min(max(new_rank, 1), len(S)-1)
+    new_alpha = float(scale*new_rank)
+  else:
+    new_rank = rank
+    new_alpha = float(scale*new_rank)
+
+  if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
+    new_rank = 1
+    new_alpha = float(scale*new_rank)
+  elif new_rank > rank:  # cap max rank at rank
+    new_rank = rank
+    new_alpha = float(scale*new_rank)
+
+  # Calculate resize info
+  s_sum = torch.sum(torch.abs(S))
+  s_rank = torch.sum(torch.abs(S[:new_rank]))
+
+  S_squared = S.pow(2)
+  s_fro = torch.sqrt(torch.sum(S_squared))
+  s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+  fro_percent = float(s_red_fro/s_fro)
+
+  param_dict["new_rank"] = new_rank
+  param_dict["new_alpha"] = new_alpha
+  param_dict["sum_retained"] = (s_rank)/s_sum
+  param_dict["fro_retained"] = fro_percent
+  param_dict["max_ratio"] = S[0]/S[new_rank]
+
+  return param_dict
+
+
-def
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
  network_alpha = None
  network_dim = None
  verbose_str = "\n"
-
-  CLAMP_QUANTILE = 0.99
+  fro_list = []

  # Extract loaded lora dim and alpha
  for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
      network_alpha = network_dim

  scale = network_alpha/network_dim
-  new_alpha = float(scale*new_rank)  # calculate new alpha from scale

-
+  if dynamic_method:
+    print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

  lora_down_weight = None
  lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
  block_down_name = None
  block_up_name = None

-  print("resizing lora...")
  with torch.no_grad():
    for key, value in tqdm(lora_sd.items()):
      if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
        conv2d = (len(lora_down_weight.size()) == 4)

        if conv2d:
-          lora_up_weight = lora_up_weight.to(args.device)
-          lora_down_weight = lora_down_weight.to(args.device)
-
-          full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
-
-          U, S, Vh = torch.linalg.svd(full_weight_matrix)
+          full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+        else:
+          full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

        if verbose:
-          U = U @ torch.diag(S)
-
-          U = U.clamp(low_val, hi_val)
-          Vh = Vh.clamp(low_val, hi_val)
-
-          if conv2d:
-            U = U.unsqueeze(2).unsqueeze(3)
-            Vh = Vh.unsqueeze(2).unsqueeze(3)
-
-          if device:
-            U = U.to(org_device)
-            Vh = Vh.to(org_device)
-
-          o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
-          o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
-          o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
+          max_ratio = param_dict['max_ratio']
+          sum_retained = param_dict['sum_retained']
+          fro_retained = param_dict['fro_retained']
+          if not np.isnan(fro_retained):
+            fro_list.append(float(fro_retained))
+
+          verbose_str+=f"{block_down_name:75} | "
+          verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
+
+        if verbose and dynamic_method:
+          verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
+        else:
+          verbose_str+=f"\n"
+
+        new_alpha = param_dict['new_alpha']
+        o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

        block_down_name = None
        block_up_name = None
        lora_down_weight = None
        lora_up_weight = None
        weights_loaded = False
+        del param_dict

  if verbose:
    print(verbose_str)
+
+  print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
  print("resizing complete")
  return o_lora_sd, network_dim, new_alpha

@@ -151,6 +274,9 @@ def resize(args):
      return torch.bfloat16
    return None

+  if args.dynamic_method and not args.dynamic_param:
+    raise Exception("If using dynamic_method, then dynamic_param is required")
+
  merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
  save_dtype = str_to_dtype(args.save_precision)
  if save_dtype is None:
@@ -159,17 +285,23 @@ def resize(args):
  print("loading Model...")
  lora_sd, metadata = load_state_dict(args.model, merge_dtype)

-  print("
-  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
+  print("Resizing Lora...")
+  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

  # update metadata
  if metadata is None:
    metadata = {}

  comment = metadata.get("ss_training_comment", "")
+
+  if not args.dynamic_method:
+    metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
+    metadata["ss_network_dim"] = str(args.new_rank)
+    metadata["ss_network_alpha"] = str(new_alpha)
+  else:
+    metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
+    metadata["ss_network_dim"] = 'Dynamic'
+    metadata["ss_network_alpha"] = 'Dynamic'

  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
  metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ if __name__ == '__main__':
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
  parser.add_argument("--verbose", action="store_true",
                      help="Display verbose resizing information / rank変更時の詳細情報を出力する")
+  parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
+                      help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
+  parser.add_argument("--dynamic_param", type=float, default=None,
+                      help="Specify target for dynamic reduction")
+

  args = parser.parse_args()
  resize(args)
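
The dynamic methods added above choose the output rank from the singular-value spectrum instead of always using --new_rank. A small sketch of the two threshold rules, mirroring index_sv_fro and index_sv_cumulative from the new code; the spectrum and targets are made up for illustration:

import torch

# Toy spectrum: a few large singular values followed by a fast-decaying tail.
S = torch.tensor([10.0, 6.0, 3.0, 1.0, 0.5, 0.25, 0.1, 0.05])

def rank_for_fro(S, target):
    # keep enough values so the retained Frobenius norm fraction reaches `target`
    S_squared = S.pow(2)
    cum = torch.cumsum(S_squared, dim=0) / float(torch.sum(S_squared))
    return min(int(torch.searchsorted(cum, target ** 2)) + 1, len(S) - 1)

def rank_for_cumulative(S, target):
    # keep enough values so the retained share of sum(S) reaches `target`
    cum = torch.cumsum(S, dim=0) / float(torch.sum(S))
    return min(int(torch.searchsorted(cum, target)) + 1, len(S) - 1)

print(rank_for_fro(S, 0.9))         # -> 2
print(rank_for_cumulative(S, 0.9))  # -> 3
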
networks/svd_merge_lora.py
CHANGED
@@ -23,19 +23,20 @@ def load_state_dict(file_name, dtype):
  return sd


-def save_to_file(file_name,
+def save_to_file(file_name, state_dict, dtype):
  if dtype is not None:
    for key in list(state_dict.keys()):
      if type(state_dict[key]) == torch.Tensor:
        state_dict[key] = state_dict[key].to(dtype)

  if os.path.splitext(file_name)[1] == '.safetensors':
-    save_file(
+    save_file(state_dict, file_name)
  else:
-    torch.save(
+    torch.save(state_dict, file_name)


-def merge_lora_models(models, ratios, new_rank, device,
+def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
+  print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
  merged_sd = {}
  for model, ratio in zip(models, ratios):
    print(f"loading: {model}")
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
      in_dim = down_weight.size()[1]
      out_dim = up_weight.size()[0]
      conv2d = len(down_weight.size()) == 4
-
+      kernel_size = None if not conv2d else down_weight.size()[2:4]
+      # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)

      # make original weight if not exist
      if lora_module_name not in merged_sd:
-        weight = torch.zeros((out_dim, in_dim,
+        weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
        if device:
          weight = weight.to(device)
      else:
@@ -75,11 +77,18 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):

      # W <- W + U * D
      scale = (alpha / network_dim)
+
+      if device:  # and isinstance(scale, torch.Tensor):
+        scale = scale.to(device)
+
      if not conv2d:  # linear
        weight = weight + ratio * (up_weight @ down_weight) * scale
-
+      elif kernel_size == (1, 1):
        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                   ).unsqueeze(2).unsqueeze(3) * scale
+      else:
+        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+        weight = weight + ratio * conved * scale

      merged_sd[lora_module_name] = weight

@@ -89,16 +98,26 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
  with torch.no_grad():
    for lora_module_name, mat in tqdm(list(merged_sd.items())):
      conv2d = (len(mat.size()) == 4)
+      kernel_size = None if not conv2d else mat.size()[2:4]
+      conv2d_3x3 = conv2d and kernel_size != (1, 1)
+      out_dim, in_dim = mat.size()[0:2]
+
      if conv2d:
-
+        if conv2d_3x3:
+          mat = mat.flatten(start_dim=1)
+        else:
+          mat = mat.squeeze()
+
+      module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+      module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

      U, S, Vh = torch.linalg.svd(mat)

-      U = U[:, :
-      S = S[:
+      U = U[:, :module_new_rank]
+      S = S[:module_new_rank]
      U = U @ torch.diag(S)

-      Vh = Vh[:
+      Vh = Vh[:module_new_rank, :]

      dist = torch.cat([U.flatten(), Vh.flatten()])
      hi_val = torch.quantile(dist, CLAMP_QUANTILE)
@@ -107,16 +126,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
      U = U.clamp(low_val, hi_val)
      Vh = Vh.clamp(low_val, hi_val)

+      if conv2d:
+        U = U.reshape(out_dim, module_new_rank, 1, 1)
+        Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
+
      up_weight = U
      down_weight = Vh

-      if conv2d:
-        up_weight = up_weight.unsqueeze(2).unsqueeze(3)
-        down_weight = down_weight.unsqueeze(2).unsqueeze(3)
-
      merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
      merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
-      merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(
+      merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)

  return merged_lora_sd

@@ -138,10 +157,11 @@ def merge(args):
  if save_dtype is None:
    save_dtype = merge_dtype

-
+  new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
+  state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)

  print(f"saving model to: {args.save_to}")
-  save_to_file(args.save_to, state_dict,
+  save_to_file(args.save_to, state_dict, save_dtype)


if __name__ == '__main__':
@@ -158,6 +178,8 @@ if __name__ == '__main__':
                      help="ratios for each model / それぞれのLoRAモデルの比率")
  parser.add_argument("--new_rank", type=int, default=4,
                      help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
+  parser.add_argument("--new_conv_rank", type=int, default=None,
+                      help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")

  args = parser.parse_args()
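
merge_lora_models above re-factorizes each merged delta weight with an SVD, truncates to the per-module rank, and clamps outliers at a quantile before writing lora_up / lora_down. A minimal sketch of that truncate-and-clamp step for a linear module; the matrix size and the low_val = -hi_val convention are assumptions of this sketch, not taken from the file:

import torch

CLAMP_QUANTILE = 0.99  # same constant name the script uses
new_rank = 4

# Hypothetical merged linear delta weight (out_dim x in_dim)
mat = torch.randn(64, 32)

U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]

# clamp extreme values before saving as lora_up / lora_down
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
up_weight = U.clamp(low_val, hi_val)       # (64, 4) -> lora_up
down_weight = Vh.clamp(low_val, hi_val)    # (4, 32) -> lora_down

print(up_weight.shape, down_weight.shape, (up_weight @ down_weight).shape)
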
requirements.txt
CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
gradio==3.16.2
altair==4.2.2
easygui==0.98.3
+toml==0.10.2
+voluptuous==0.13.1
# for BLIP captioning
requests==2.28.2
timm==0.6.12
@@ -21,5 +23,4 @@ fairscale==0.4.13
tensorflow==2.10.1
huggingface-hub==0.12.0
# for kohya_ss library
-#locon.locon_kohya
.
train_db.py
CHANGED
@@ -15,7 +15,11 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)


def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):

  tokenizer = train_util.load_tokenizer(args)

+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "reg_data_dir"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+      }]
+    }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+  if args.no_token_padding:
+    train_dataset_group.disable_token_padding()

  if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset_group)
    return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
  # acceleratorを準備する
  print("prepare accelerator")

@@ -91,7 +104,7 @@ def train(args):
    vae.requires_grad_(False)
    vae.eval()
    with torch.no_grad():
+      train_dataset_group.cache_latents(vae)
    vae.to("cpu")
    if torch.cuda.is_available():
      torch.cuda.empty_cache()
@@ -115,38 +128,18 @@ def train(args):

  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
  if train_text_encoder:
    trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
  else:
    trainable_params = unet.parameters()

-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
-
+    train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:
@@ -156,9 +149,10 @@ def train(args):
  if args.stop_text_encoder_training is None:
    args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

-  # lr schedulerを用意する
-  lr_scheduler =
+  # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
  if args.full_fp16:
@@ -195,8 +189,8 @@ def train(args):
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
-  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {
-  print(f"  num reg images / 正則化画像の数: {
+  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+  print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"  num epochs / epoch数: {num_train_epochs}")
  print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -217,7 +211,7 @@ def train(args):
  loss_total = 0.0
  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset_group.set_current_epoch(epoch + 1)

    # 指定したステップ数までText Encoderを学習する:epoch最初の状態
    unet.train()
@@ -281,12 +275,12 @@ def train(args):
        loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

        accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
          if train_text_encoder:
            params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
          else:
            params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip,
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()
@@ -297,9 +291,13 @@ def train(args):
        progress_bar.update(1)
        global_step += 1

+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
        current_loss = loss.detach().item()
        if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+            logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
          accelerator.log(logs, step=global_step)

        if epoch == 0:
@@ -326,6 +324,8 @@ def train(args):
      train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                            save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)

+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
  is_main_process = accelerator.is_main_process
  if is_main_process:
    unet = unwrap_model(unet)
@@ -352,6 +352,8 @@ if __name__ == '__main__':
  train_util.add_dataset_arguments(parser, True, False, True)
  train_util.add_training_arguments(parser, True)
  train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

  parser.add_argument("--no_token_padding", action="store_true",
                      help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
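
train_db.py (and train_network.py below) now delegate optimizer construction to train_util.get_optimizer instead of the inline 8-bit Adam / Lion / AdamW selection shown as removed above. A sketch of that selection logic as a standalone helper, based only on the removed block; the helper name is made up, and the real signature and return values of get_optimizer may differ:

import torch

def pick_optimizer_class(use_8bit_adam: bool, use_lion_optimizer: bool):
    # Mirrors the inline selection this commit removes from the training scripts;
    # the actual replacement lives in library.train_util.get_optimizer.
    if use_8bit_adam:
        import bitsandbytes as bnb
        return bnb.optim.AdamW8bit
    if use_lion_optimizer:
        import lion_pytorch
        return lion_pytorch.Lion
    return torch.optim.AdamW

params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = pick_optimizer_class(False, False)(params, lr=1e-4)
print(type(optimizer).__name__)  # AdamW
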
train_network.py
CHANGED
|
@@ -1,8 +1,4 @@
|
|
| 1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
| 2 |
-
from torch.optim import Optimizer
|
| 3 |
-
from torch.cuda.amp import autocast
|
| 4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 5 |
-
from typing import Optional, Union
|
| 6 |
import importlib
|
| 7 |
import argparse
|
| 8 |
import gc
|
|
@@ -15,92 +11,39 @@ import json
|
|
| 15 |
from tqdm import tqdm
|
| 16 |
import torch
|
| 17 |
from accelerate.utils import set_seed
|
| 18 |
-
import diffusers
|
| 19 |
from diffusers import DDPMScheduler
|
| 20 |
|
| 21 |
import library.train_util as train_util
|
| 22 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def collate_fn(examples):
|
| 26 |
return examples[0]
|
| 27 |
|
| 28 |
|
|
|
|
| 29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
| 30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
| 31 |
|
| 32 |
if args.network_train_unet_only:
|
| 33 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
| 34 |
elif args.network_train_text_encoder_only:
|
| 35 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
| 36 |
else:
|
| 37 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
| 38 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
| 39 |
-
|
| 40 |
-
return logs
|
| 41 |
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
| 45 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
| 46 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def get_scheduler_fix(
|
| 50 |
-
name: Union[str, SchedulerType],
|
| 51 |
-
optimizer: Optimizer,
|
| 52 |
-
num_warmup_steps: Optional[int] = None,
|
| 53 |
-
num_training_steps: Optional[int] = None,
|
| 54 |
-
num_cycles: int = 1,
|
| 55 |
-
power: float = 1.0,
|
| 56 |
-
):
|
| 57 |
-
"""
|
| 58 |
-
Unified API to get any scheduler from its name.
|
| 59 |
-
Args:
|
| 60 |
-
name (`str` or `SchedulerType`):
|
| 61 |
-
The name of the scheduler to use.
|
| 62 |
-
optimizer (`torch.optim.Optimizer`):
|
| 63 |
-
The optimizer that will be used during training.
|
| 64 |
-
num_warmup_steps (`int`, *optional*):
|
| 65 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
| 66 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
| 67 |
-
num_training_steps (`int``, *optional*):
|
| 68 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
| 69 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
| 70 |
-
num_cycles (`int`, *optional*):
|
| 71 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
| 72 |
-
power (`float`, *optional*, defaults to 1.0):
|
| 73 |
-
Power factor. See `POLYNOMIAL` scheduler
|
| 74 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
| 75 |
-
The index of the last epoch when resuming training.
|
| 76 |
-
"""
|
| 77 |
-
name = SchedulerType(name)
|
| 78 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
| 79 |
-
if name == SchedulerType.CONSTANT:
|
| 80 |
-
return schedule_func(optimizer)
|
| 81 |
-
|
| 82 |
-
# All other schedulers require `num_warmup_steps`
|
| 83 |
-
if num_warmup_steps is None:
|
| 84 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
| 85 |
-
|
| 86 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
| 87 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
| 88 |
-
|
| 89 |
-
# All other schedulers require `num_training_steps`
|
| 90 |
-
if num_training_steps is None:
|
| 91 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
| 92 |
-
|
| 93 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
| 94 |
-
return schedule_func(
|
| 95 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
if name == SchedulerType.POLYNOMIAL:
|
| 99 |
-
return schedule_func(
|
| 100 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
| 104 |
|
| 105 |
|
| 106 |
def train(args):
|
|
@@ -111,6 +54,7 @@ def train(args):
|
|
| 111 |
|
| 112 |
cache_latents = args.cache_latents
|
| 113 |
use_dreambooth_method = args.in_json is None
|
|
|
|
| 114 |
|
| 115 |
if args.seed is not None:
|
| 116 |
set_seed(args.seed)
|
|
@@ -118,38 +62,51 @@ def train(args):
|
|
| 118 |
tokenizer = train_util.load_tokenizer(args)
|
| 119 |
|
| 120 |
# データセットを準備する
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
else:
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
if args.debug_dataset:
|
| 144 |
-
train_util.debug_dataset(
|
| 145 |
return
|
| 146 |
-
if len(
|
| 147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
| 148 |
return
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# acceleratorを準備する
|
| 151 |
print("prepare accelerator")
|
| 152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
|
| 153 |
|
| 154 |
# mixed precisionに対応した型を用意しておき適宜castする
|
| 155 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
|
@@ -161,7 +118,7 @@ def train(args):
|
|
| 161 |
if args.lowram:
|
| 162 |
text_encoder.to("cuda")
|
| 163 |
unet.to("cuda")
|
| 164 |
-
|
| 165 |
# モデルに xformers とか memory efficient attention を組み込む
|
| 166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
| 167 |
|
|
@@ -171,13 +128,15 @@ def train(args):
|
|
| 171 |
vae.requires_grad_(False)
|
| 172 |
vae.eval()
|
| 173 |
with torch.no_grad():
|
| 174 |
-
|
| 175 |
vae.to("cpu")
|
| 176 |
if torch.cuda.is_available():
|
| 177 |
torch.cuda.empty_cache()
|
| 178 |
gc.collect()
|
| 179 |
|
| 180 |
# prepare network
|
|
|
|
|
|
|
| 181 |
print("import network module:", args.network_module)
|
| 182 |
network_module = importlib.import_module(args.network_module)
|
| 183 |
|
|
@@ -208,48 +167,25 @@ def train(args):
  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")

-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
-  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
-
  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
-
-  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
-  …
-                                   num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-                                   num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
  if args.full_fp16:
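
Note (editor's illustration, not part of the commit): the per-script optimizer selection removed above is replaced by a single `train_util.get_optimizer(args, trainable_params)` call in the new code further down. The sketch below only illustrates the shape of such a consolidation; the real helper's signature, supported `--optimizer_type` values, and return values are assumptions here, not the library implementation.

```python
# Illustrative sketch only -- not the actual train_util.get_optimizer.
import torch

def get_optimizer_sketch(args, trainable_params):
  # assumption: a single args.optimizer_type string replaces the old
  # use_8bit_adam / use_lion_optimizer boolean flags
  name = (getattr(args, "optimizer_type", "") or "AdamW").lower()
  if name == "adamw8bit":
    import bitsandbytes as bnb  # optional dependency
    optimizer_class = bnb.optim.AdamW8bit
  elif name == "lion":
    import lion_pytorch
    optimizer_class = lion_pytorch.Lion
  else:
    optimizer_class = torch.optim.AdamW
  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
  return optimizer_name, "", optimizer
```
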
@@ -317,17 +253,21 @@ def train(args):
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-  …
  metadata = {
      "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
      "ss_training_started_at": training_started_at,  # unix timestamp
@@ -335,12 +275,10 @@ def train(args):
      "ss_learning_rate": args.learning_rate,
      "ss_text_encoder_lr": args.text_encoder_lr,
      "ss_unet_lr": args.unet_lr,
-      "ss_num_train_images": train_dataset.num_train_images,
-      "ss_num_reg_images": train_dataset.num_reg_images,
      "ss_num_batches_per_epoch": len(train_dataloader),
      "ss_num_epochs": num_train_epochs,
-      "ss_batch_size_per_device": args.train_batch_size,
-      "ss_total_batch_size": total_batch_size,
      "ss_gradient_checkpointing": args.gradient_checkpointing,
      "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
      "ss_max_train_steps": args.max_train_steps,
@@ -352,33 +290,156 @@ def train(args):
      "ss_mixed_precision": args.mixed_precision,
      "ss_full_fp16": bool(args.full_fp16),
      "ss_v2": bool(args.v2),
-      "ss_resolution": args.resolution,
      "ss_clip_skip": args.clip_skip,
      "ss_max_token_length": args.max_token_length,
-      "ss_color_aug": bool(args.color_aug),
-      "ss_flip_aug": bool(args.flip_aug),
-      "ss_random_crop": bool(args.random_crop),
-      "ss_shuffle_caption": bool(args.shuffle_caption),
      "ss_cache_latents": bool(args.cache_latents),
-      "ss_enable_bucket": bool(train_dataset.enable_bucket),
-      "ss_min_bucket_reso": train_dataset.min_bucket_reso,
-      "ss_max_bucket_reso": train_dataset.max_bucket_reso,
      "ss_seed": args.seed,
-      …
      "ss_noise_offset": args.noise_offset,
-      "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
-      "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
-      "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
-      "ss_bucket_info": json.dumps(train_dataset.bucket_info),
      "ss_training_comment": args.training_comment,  # will not be updated after training
      "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-      "ss_optimizer": optimizer_name
  }

-  …

  if args.pretrained_model_name_or_path is not None:
    sd_model_name = args.pretrained_model_name_or_path
    if os.path.exists(sd_model_name):
@@ -397,6 +458,13 @@ def train(args):
  metadata = {k: str(v) for k, v in metadata.items()}

  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
  global_step = 0

@@ -409,8 +477,9 @@ def train(args):
  loss_list = []
  loss_total = 0.0
  for epoch in range(num_train_epochs):
-    print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.set_current_epoch(epoch + 1)

    metadata["ss_epoch"] = str(epoch+1)

@@ -447,7 +516,7 @@ def train(args):
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise residual
-        with autocast():
          noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if args.v_parameterization:
@@ -465,9 +534,9 @@ def train(args):
        loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

        accelerator.backward(loss)
-        if accelerator.sync_gradients:
          params_to_clip = network.get_trainable_params()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)

        optimizer.step()
        lr_scheduler.step()
@@ -478,6 +547,8 @@ def train(args):
        progress_bar.update(1)
        global_step += 1

      current_loss = loss.detach().item()
      if epoch == 0:
        loss_list.append(current_loss)
@@ -508,8 +579,9 @@ def train(args):
      def save_func():
        ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)
        print(f"saving checkpoint: {ckpt_file}")
-        unwrap_model(network).save_weights(ckpt_file, save_dtype, …)

      def remove_old_func(old_epoch_no):
        old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -518,15 +590,18 @@ def train(args):
        print(f"removing old checkpoint: {old_ckpt_file}")
        os.remove(old_ckpt_file)

-      …
-      …
-      …

    # end of epoch

  metadata["ss_epoch"] = str(num_train_epochs)

-  is_main_process = accelerator.is_main_process
  if is_main_process:
    network = unwrap_model(network)
@@ -545,7 +620,7 @@ def train(args):
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    print(f"save trained model to {ckpt_file}")
-    network.save_weights(ckpt_file, save_dtype, …)
    print("model saved.")

@@ -555,6 +630,8 @@ if __name__ == '__main__':
  train_util.add_sd_models_arguments(parser)
  train_util.add_dataset_arguments(parser, True, True, True)
  train_util.add_training_arguments(parser, True)

  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -562,10 +639,6 @@ if __name__ == '__main__':
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
-  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
-  parser.add_argument("--lr_scheduler_power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

  parser.add_argument("--network_weights", type=str, default=None,
                      help="pretrained weights for network / 学習するネットワークの初期重み")
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
import gc
…
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

import library.train_util as train_util
+from library.train_util import (
+    DreamBoothDataset,
+)
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)


def collate_fn(examples):
  return examples[0]


+# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
  logs = {"loss/current": current_loss, "loss/average": avr_loss}

  if args.network_train_unet_only:
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
  elif args.network_train_text_encoder_only:
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
  else:
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be same to textencoder

+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
+    logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d'] * lr_scheduler.optimizers[-1].param_groups[0]['lr']

+  return logs
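
Note (editor's illustration, not part of the commit): `generate_step_logs` is consumed from the logging branch of the training step loop further down in this file. The hunks shown here do not include that call site, so the snippet below is only a minimal sketch of how it would be invoked; the surrounding names (`current_loss`, `avr_loss`, `global_step`) are assumed from context.

```python
# Sketch of the assumed call site inside the step loop.
if args.logging_dir is not None:
  logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
  accelerator.log(logs, step=global_step)
```
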


def train(args):
…
  cache_latents = args.cache_latents
  use_dreambooth_method = args.in_json is None
+  use_user_config = args.dataset_config is not None

  if args.seed is not None:
    set_seed(args.seed)

  tokenizer = train_util.load_tokenizer(args)

  # データセットを準備する
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
+  if use_user_config:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print(
+          "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
  else:
+    if use_dreambooth_method:
+      print("Use DreamBooth method.")
+      user_config = {
+          "datasets": [{
+              "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+          }]
+      }
+    else:
+      print("Train with captions.")
+      user_config = {
+          "datasets": [{
+              "subsets": [{
+                  "image_dir": args.train_data_dir,
+                  "metadata_file": args.in_json,
+              }]
+          }]
+      }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

  if args.debug_dataset:
+    train_util.debug_dataset(train_dataset_group)
    return
+  if len(train_dataset_group) == 0:
    print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
    return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(
+    ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
  # acceleratorを準備する
  print("prepare accelerator")
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
+  is_main_process = accelerator.is_main_process

  # mixed precisionに対応した型を用意しておき適宜castする
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
…
  if args.lowram:
    text_encoder.to("cuda")
    unet.to("cuda")
+
  # モデルに xformers とか memory efficient attention を組み込む
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
…
  vae.requires_grad_(False)
  vae.eval()
  with torch.no_grad():
+    train_dataset_group.cache_latents(vae)
  vae.to("cpu")
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
  gc.collect()

  # prepare network
+  import sys
+  sys.path.append(os.path.dirname(__file__))
  print("import network module:", args.network_module)
  network_module = importlib.import_module(args.network_module)
…
  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")

  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+  optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
+      train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+    if is_main_process:
+      print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
  if args.full_fp16:
…
  args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

  # 学習する
+  # TODO: find a way to handle total batch size when there are multiple datasets
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+  if is_main_process:
+    print("running training / 学習開始")
+    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    print(f"  num epochs / epoch数: {num_train_epochs}")
+    print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+    # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+  # TODO refactor metadata creation and move to util
  metadata = {
      "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
      "ss_training_started_at": training_started_at,  # unix timestamp
…
      "ss_learning_rate": args.learning_rate,
      "ss_text_encoder_lr": args.text_encoder_lr,
      "ss_unet_lr": args.unet_lr,
+      "ss_num_train_images": train_dataset_group.num_train_images,
+      "ss_num_reg_images": train_dataset_group.num_reg_images,
      "ss_num_batches_per_epoch": len(train_dataloader),
      "ss_num_epochs": num_train_epochs,
…
      "ss_gradient_checkpointing": args.gradient_checkpointing,
      "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
      "ss_max_train_steps": args.max_train_steps,
…
      "ss_mixed_precision": args.mixed_precision,
      "ss_full_fp16": bool(args.full_fp16),
      "ss_v2": bool(args.v2),
      "ss_clip_skip": args.clip_skip,
      "ss_max_token_length": args.max_token_length,
      "ss_cache_latents": bool(args.cache_latents),
      "ss_seed": args.seed,
+      "ss_lowram": args.lowram,
      "ss_noise_offset": args.noise_offset,
      "ss_training_comment": args.training_comment,  # will not be updated after training
      "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
+      "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
+      "ss_max_grad_norm": args.max_grad_norm,
+      "ss_caption_dropout_rate": args.caption_dropout_rate,
+      "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
+      "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
+      "ss_face_crop_aug_range": args.face_crop_aug_range,
+      "ss_prior_loss_weight": args.prior_loss_weight,
  }

+  if use_user_config:
+    # save metadata of multiple datasets
+    # NOTE: pack "ss_datasets" value as json one time
+    # or should also pack nested collections as json?
+    datasets_metadata = []
+    tag_frequency = {}  # merge tag frequency for metadata editor
+    dataset_dirs_info = {}  # merge subset dirs for metadata editor
+
+    for dataset in train_dataset_group.datasets:
+      is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
+      dataset_metadata = {
+          "is_dreambooth": is_dreambooth_dataset,
+          "batch_size_per_device": dataset.batch_size,
+          "num_train_images": dataset.num_train_images,  # includes repeating
+          "num_reg_images": dataset.num_reg_images,
+          "resolution": (dataset.width, dataset.height),
+          "enable_bucket": bool(dataset.enable_bucket),
+          "min_bucket_reso": dataset.min_bucket_reso,
+          "max_bucket_reso": dataset.max_bucket_reso,
+          "tag_frequency": dataset.tag_frequency,
+          "bucket_info": dataset.bucket_info,
+      }
+
+      subsets_metadata = []
+      for subset in dataset.subsets:
+        subset_metadata = {
+            "img_count": subset.img_count,
+            "num_repeats": subset.num_repeats,
+            "color_aug": bool(subset.color_aug),
+            "flip_aug": bool(subset.flip_aug),
+            "random_crop": bool(subset.random_crop),
+            "shuffle_caption": bool(subset.shuffle_caption),
+            "keep_tokens": subset.keep_tokens,
+        }
+
+        image_dir_or_metadata_file = None
+        if subset.image_dir:
+          image_dir = os.path.basename(subset.image_dir)
+          subset_metadata["image_dir"] = image_dir
+          image_dir_or_metadata_file = image_dir
+
+        if is_dreambooth_dataset:
+          subset_metadata["class_tokens"] = subset.class_tokens
+          subset_metadata["is_reg"] = subset.is_reg
+          if subset.is_reg:
+            image_dir_or_metadata_file = None  # not merging reg dataset
+        else:
+          metadata_file = os.path.basename(subset.metadata_file)
+          subset_metadata["metadata_file"] = metadata_file
+          image_dir_or_metadata_file = metadata_file  # may overwrite
+
+        subsets_metadata.append(subset_metadata)
+
+        # merge dataset dir: not reg subset only
+        # TODO update additional-network extension to show detailed dataset config from metadata
+        if image_dir_or_metadata_file is not None:
+          # datasets may have a certain dir multiple times
+          v = image_dir_or_metadata_file
+          i = 2
+          while v in dataset_dirs_info:
+            v = image_dir_or_metadata_file + f" ({i})"
+            i += 1
+          image_dir_or_metadata_file = v
+
+          dataset_dirs_info[image_dir_or_metadata_file] = {
+              "n_repeats": subset.num_repeats,
+              "img_count": subset.img_count
+          }
+
+      dataset_metadata["subsets"] = subsets_metadata
+      datasets_metadata.append(dataset_metadata)
+
+      # merge tag frequency:
+      for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
+        # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
+        # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
+        # なので、ここで複数datasetの回数を合算してもあまり意味はない
+        if ds_dir_name in tag_frequency:
+          continue
+        tag_frequency[ds_dir_name] = ds_freq_for_dir
+
+    metadata["ss_datasets"] = json.dumps(datasets_metadata)
+    metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
+    metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
+  else:
+    # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
+    assert len(
+        train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
+
+    dataset = train_dataset_group.datasets[0]
+
+    dataset_dirs_info = {}
+    reg_dataset_dirs_info = {}
+    if use_dreambooth_method:
+      for subset in dataset.subsets:
+        info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
+        info[os.path.basename(subset.image_dir)] = {
+            "n_repeats": subset.num_repeats,
+            "img_count": subset.img_count
+        }
+    else:
+      for subset in dataset.subsets:
+        dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
+            "n_repeats": subset.num_repeats,
+            "img_count": subset.img_count
+        }
+
+    metadata.update({
+        "ss_batch_size_per_device": args.train_batch_size,
+        "ss_total_batch_size": total_batch_size,
+        "ss_resolution": args.resolution,
+        "ss_color_aug": bool(args.color_aug),
+        "ss_flip_aug": bool(args.flip_aug),
+        "ss_random_crop": bool(args.random_crop),
+        "ss_shuffle_caption": bool(args.shuffle_caption),
+        "ss_enable_bucket": bool(dataset.enable_bucket),
+        "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
+        "ss_min_bucket_reso": dataset.min_bucket_reso,
+        "ss_max_bucket_reso": dataset.max_bucket_reso,
+        "ss_keep_tokens": args.keep_tokens,
+        "ss_dataset_dirs": json.dumps(dataset_dirs_info),
+        "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
+        "ss_tag_frequency": json.dumps(dataset.tag_frequency),
+        "ss_bucket_info": json.dumps(dataset.bucket_info),
+    })
+
+  # add extra args
+  if args.network_args:
+    metadata["ss_network_args"] = json.dumps(net_kwargs)
+    # for key, value in net_kwargs.items():
+    #   metadata["ss_arg_" + key] = value
+
+  # model name and hash
  if args.pretrained_model_name_or_path is not None:
    sd_model_name = args.pretrained_model_name_or_path
    if os.path.exists(sd_model_name):
…
  metadata = {k: str(v) for k, v in metadata.items()}

+  # make minimum metadata for filtering
+  minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
+  minimum_metadata = {}
+  for key in minimum_keys:
+    if key in metadata:
+      minimum_metadata[key] = metadata[key]
+
  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
  global_step = 0
…
  loss_list = []
  loss_total = 0.0
  for epoch in range(num_train_epochs):
+    if is_main_process:
+      print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset_group.set_current_epoch(epoch + 1)

    metadata["ss_epoch"] = str(epoch+1)
…
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise residual
+        with accelerator.autocast():
          noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if args.v_parameterization:
…
        loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

        accelerator.backward(loss)
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
          params_to_clip = network.get_trainable_params()
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()
…
        progress_bar.update(1)
        global_step += 1

+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
      current_loss = loss.detach().item()
      if epoch == 0:
        loss_list.append(current_loss)
…
      def save_func():
        ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)
+        metadata["ss_training_finished_at"] = str(time.time())
        print(f"saving checkpoint: {ckpt_file}")
+        unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)

      def remove_old_func(old_epoch_no):
        old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
…
        print(f"removing old checkpoint: {old_ckpt_file}")
        os.remove(old_ckpt_file)

+      if is_main_process:
+        saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+        if saving and args.save_state:
+          train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+
+      train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

    # end of epoch

  metadata["ss_epoch"] = str(num_train_epochs)
+  metadata["ss_training_finished_at"] = str(time.time())

  if is_main_process:
    network = unwrap_model(network)
…
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    print(f"save trained model to {ckpt_file}")
+    network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
    print("model saved.")

…
  train_util.add_sd_models_arguments(parser)
  train_util.add_dataset_arguments(parser, True, True, True)
  train_util.add_training_arguments(parser, True)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
…
  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
…
  parser.add_argument("--network_weights", type=str, default=None,
                      help="pretrained weights for network / 学習するネットワークの初期重み")
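
Note (editor's illustration, not part of the commit): two of the changes above rescale step counts for multi-process training. `max_train_steps` is now derived from `math.ceil(len(train_dataloader) / accelerator.num_processes)`, while the scheduler's `num_training_steps` multiplies it back up by `num_processes * gradient_accumulation_steps`. A small worked example with hypothetical numbers:

```python
import math

# Hypothetical numbers, for illustration only.
num_batches_per_epoch = 1000      # len(train_dataloader)
num_processes = 2                 # accelerator.num_processes
gradient_accumulation_steps = 4
max_train_epochs = 10

# optimizer steps actually taken per process
max_train_steps = max_train_epochs * math.ceil(num_batches_per_epoch / num_processes)   # 10 * 500 = 5000
# step budget handed to the LR scheduler, covering all processes and accumulation
scheduler_steps = max_train_steps * num_processes * gradient_accumulation_steps          # 40000
print(max_train_steps, scheduler_steps)
```
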
train_textual_inversion.py
CHANGED
@@ -11,7 +11,11 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
-…

imagenet_templates_small = [
    "a photo of a {}",
@@ -79,7 +83,6 @@ def train(args):
  train_util.prepare_dataset_args(args, True)

  cache_latents = args.cache_latents
-  use_dreambooth_method = args.in_json is None

  if args.seed is not None:
    set_seed(args.seed)
@@ -139,21 +142,35 @@ def train(args):
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

  # データセットを準備する
-  …
  else:
-  …

  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
  if use_template:
@@ -163,20 +180,30 @@ def train(args):
      captions = []
      for tmpl in templates:
        captions.append(tmpl.format(replace_to))
-      …
-  elif args.num_vectors_per_token > 1:
-    replace_to = " ".join(token_strings)
-    train_dataset.add_replacement(args.token_string, replace_to)

-  …

  if args.debug_dataset:
-    train_util.debug_dataset(train_dataset)
    return
-  if len(train_dataset) == 0:
    print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
    return

  # モデルに xformers とか memory efficient attention を組み込む
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -186,7 +213,7 @@ def train(args):
  vae.requires_grad_(False)
  vae.eval()
  with torch.no_grad():
-    train_dataset.cache_latents(vae)
  vae.to("cpu")
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
@@ -198,35 +225,14 @@ def train(args):

  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
  trainable_params = text_encoder.get_input_embeddings().parameters()
-
-  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:
@@ -234,8 +240,9 @@ def train(args):
    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
-  lr_scheduler = …
-  …

  # acceleratorがなんかよろしくやってくれるらしい
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -283,8 +290,8 @@ def train(args):
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
-  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
-  print(f"  num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"  num epochs / epoch数: {num_train_epochs}")
  print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -303,12 +310,11 @@ def train(args):

  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.set_current_epoch(epoch + 1)

    text_encoder.train()

    loss_total = 0
-    bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
    for step, batch in enumerate(train_dataloader):
      with accelerator.accumulate(text_encoder):
        with torch.no_grad():
@@ -357,9 +363,9 @@ def train(args):
        loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

        accelerator.backward(loss)
-        if accelerator.sync_gradients:
          params_to_clip = text_encoder.get_input_embeddings().parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)

        optimizer.step()
        lr_scheduler.step()
@@ -374,9 +380,14 @@ def train(args):
        progress_bar.update(1)
        global_step += 1

      current_loss = loss.detach().item()
      if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
        accelerator.log(logs, step=global_step)

      loss_total += current_loss
@@ -394,8 +405,6 @@ def train(args):
    accelerator.wait_for_everyone()

    updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
-    # d = updated_embs - bef_epo_embs
-    # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())

    if args.save_every_n_epochs is not None:
      model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -417,6 +426,9 @@ def train(args):
      if saving and args.save_state:
        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

    # end of epoch

  is_main_process = accelerator.is_main_process
@@ -491,6 +503,8 @@ if __name__ == '__main__':
  train_util.add_sd_models_arguments(parser)
  train_util.add_dataset_arguments(parser, True, True, False)
  train_util.add_training_arguments(parser, True)

  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
                      help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
from diffusers import DDPMScheduler

import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)

imagenet_templates_small = [
    "a photo of a {}",
…
  train_util.prepare_dataset_args(args, True)

  cache_latents = args.cache_latents

  if args.seed is not None:
    set_seed(args.seed)
…
  print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

  # データセットを準備する
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
  else:
+    use_dreambooth_method = args.in_json is None
+    if use_dreambooth_method:
+      print("Use DreamBooth method.")
+      user_config = {
+          "datasets": [{
+              "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+          }]
+      }
+    else:
+      print("Train with captions.")
+      user_config = {
+          "datasets": [{
+              "subsets": [{
+                  "image_dir": args.train_data_dir,
+                  "metadata_file": args.in_json,
+              }]
+          }]
+      }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

  # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
  if use_template:
…
      captions = []
      for tmpl in templates:
        captions.append(tmpl.format(replace_to))
+      train_dataset_group.add_replacement("", captions)

+    if args.num_vectors_per_token > 1:
+      prompt_replacement = (args.token_string, replace_to)
+    else:
+      prompt_replacement = None
+  else:
+    if args.num_vectors_per_token > 1:
+      replace_to = " ".join(token_strings)
+      train_dataset_group.add_replacement(args.token_string, replace_to)
+      prompt_replacement = (args.token_string, replace_to)
+    else:
+      prompt_replacement = None

  if args.debug_dataset:
+    train_util.debug_dataset(train_dataset_group, show_input_ids=True)
    return
+  if len(train_dataset_group) == 0:
    print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
    return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
  # モデルに xformers とか memory efficient attention を組み込む
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

…
  vae.requires_grad_(False)
  vae.eval()
  with torch.no_grad():
+    train_dataset_group.cache_latents(vae)
  vae.to("cpu")
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
…
  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")

  trainable_params = text_encoder.get_input_embeddings().parameters()
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
+      train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:
…
    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # acceleratorがなんかよろしくやってくれるらしい
  text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
…
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
+  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+  print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"  num epochs / epoch数: {num_train_epochs}")
  print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
…
  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset_group.set_current_epoch(epoch + 1)

    text_encoder.train()

    loss_total = 0
    for step, batch in enumerate(train_dataloader):
      with accelerator.accumulate(text_encoder):
        with torch.no_grad():
…
        loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

        accelerator.backward(loss)
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
          params_to_clip = text_encoder.get_input_embeddings().parameters()
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()
…
        progress_bar.update(1)
        global_step += 1

+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
+                                 vae, tokenizer, text_encoder, unet, prompt_replacement)
+
      current_loss = loss.detach().item()
      if args.logging_dir is not None:
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
        accelerator.log(logs, step=global_step)

      loss_total += current_loss
…
    accelerator.wait_for_everyone()

    updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

    if args.save_every_n_epochs is not None:
      model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
…
      if saving and args.save_state:
        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
+                             vae, tokenizer, text_encoder, unet, prompt_replacement)
+
    # end of epoch

  is_main_process = accelerator.is_main_process
…
  train_util.add_sd_models_arguments(parser)
  train_util.add_dataset_arguments(parser, True, True, False)
  train_util.add_training_arguments(parser, True)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

  parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
                      help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")