Commit 0d34cbd · committed by abc · Parent(s): b9a80b5
Upload 44 files
Files changed:
- .gitattributes +1 -0
- fine_tune.py +50 -45
- gen_img_diffusers.py +234 -55
- library/model_util.py +5 -1
- library/train_util.py +853 -229
- networks/check_lora_weights.py +1 -1
- networks/extract_lora_from_models.py +44 -25
- networks/lora.py +191 -30
- networks/merge_lora.py +11 -5
- networks/resize_lora.py +187 -50
- networks/svd_merge_lora.py +40 -18
- requirements.txt +2 -1
- train_db.py +47 -45
- train_network.py +248 -175
- train_textual_inversion.py +72 -58
.gitattributes
ADDED
@@ -0,0 +1 @@
+bitsandbytes_windows/libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
fine_tune.py
CHANGED
@@ -13,7 +13,11 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
-
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)

def collate_fn(examples):
  return examples[0]

@@ -30,25 +34,36 @@ def train(args):

  tokenizer = train_util.load_tokenizer(args)

-
-
-
-
-
-
-
-
-
-
-
+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "in_json"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": [{
+          "image_dir": args.train_data_dir,
+          "metadata_file": args.in_json,
+        }]
+      }]
+    }
+
+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

  if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset_group)
    return
-  if len(
+  if len(train_dataset_group) == 0:
    print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
    return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
  # acceleratorを準備する
  print("prepare accelerator")
  accelerator, unwrap_model = train_util.prepare_accelerator(args)

@@ -109,7 +124,7 @@ def train(args):
  vae.requires_grad_(False)
  vae.eval()
  with torch.no_grad():
-
+    train_dataset_group.cache_latents(vae)
  vae.to("cpu")
  if torch.cuda.is_available():
    torch.cuda.empty_cache()

@@ -149,33 +164,13 @@ def train(args):

  # 学習に必要なクラスを準備する
  print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
-  # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

  # dataloaderを準備する
  # DataLoaderのプロセス数:0はメインプロセスになる
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
  train_dataloader = torch.utils.data.DataLoader(
-
+      train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # 学習ステップ数を計算する
  if args.max_train_epochs is not None:

@@ -183,8 +178,9 @@ def train(args):
    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # lr schedulerを用意する
-  lr_scheduler =
-
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
  if args.full_fp16:

@@ -218,7 +214,7 @@ def train(args):
  # 学習する
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
-  print(f"  num examples / サンプル数: {
+  print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"  num epochs / epoch数: {num_train_epochs}")
  print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")

@@ -237,7 +233,7 @@ def train(args):

  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
-
+    train_dataset_group.set_current_epoch(epoch + 1)

    for m in training_models:
      m.train()

@@ -286,11 +282,11 @@ def train(args):
      loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

      accelerator.backward(loss)
-      if accelerator.sync_gradients:
+      if accelerator.sync_gradients and args.max_grad_norm != 0.0:
        params_to_clip = []
        for m in training_models:
          params_to_clip.extend(m.parameters())
-        accelerator.clip_grad_norm_(params_to_clip,
+        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

      optimizer.step()
      lr_scheduler.step()

@@ -301,11 +297,16 @@ def train(args):
        progress_bar.update(1)
        global_step += 1

+        train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
      current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
      if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
        accelerator.log(logs, step=global_step)

+      # TODO moving averageにする
      loss_total += current_loss
      avr_loss = loss_total / (step+1)
      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}

@@ -315,7 +316,7 @@ def train(args):
      break

    if args.logging_dir is not None:
-      logs = {"
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
      accelerator.log(logs, step=epoch+1)

    accelerator.wait_for_everyone()

@@ -325,6 +326,8 @@ def train(args):
    train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                          save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)

+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
  is_main_process = accelerator.is_main_process
  if is_main_process:
    unet = unwrap_model(unet)

@@ -351,6 +354,8 @@ if __name__ == '__main__':
  train_util.add_dataset_arguments(parser, False, True, True)
  train_util.add_training_arguments(parser, False)
  train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

  parser.add_argument("--diffusers_xformers", action='store_true',
                      help='use xformers by diffusers / Diffusersでxformersを使用する')
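Note on the optimizer change above: the inline selection of 8-bit Adam, Lion, or plain AdamW is removed in favour of train_util.get_optimizer, which (per the call shown) returns the optimizer as its third value. The following is a minimal sketch of the dispatch the removed block performed; the helper name is ours, and the real get_optimizer also covers the new --optimizer_type options (such as DAdaptation) that are not shown in this diff.

import argparse
import torch

def pick_optimizer_class(args: argparse.Namespace):
  # Mirrors the removed inline selection: bitsandbytes 8-bit AdamW, Lion, or torch.optim.AdamW.
  if args.use_8bit_adam:
    try:
      import bitsandbytes as bnb  # optional dependency
    except ImportError:
      raise ImportError("bitsandbytes is not installed")
    return bnb.optim.AdamW8bit
  elif args.use_lion_optimizer:
    try:
      import lion_pytorch  # optional dependency
    except ImportError:
      raise ImportError("lion_pytorch is not installed")
    return lion_pytorch.Lion
  return torch.optim.AdamW

# hypothetical usage: optimizer = pick_optimizer_class(args)(params_to_optimize, lr=args.learning_rate)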
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
"""

import json
-from typing import List, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect

@@ -60,7 +60,6 @@ import math
import os
import random
import re
-from typing import Any, Callable, List, Optional, Union

import diffusers
import numpy as np

@@ -81,6 +80,9 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo

import library.model_util as model_util
+import library.train_util as train_util
+import tools.original_control_net as original_control_net
+from tools.original_control_net import ControlNetInfo

# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"

@@ -487,6 +489,9 @@ class PipelineLike():
    self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
    self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

+    # ControlNet
+    self.control_nets: List[ControlNetInfo] = []
+
  # Textual Inversion
  def add_token_replacement(self, target_token_id, rep_token_ids):
    self.token_replacements[target_token_id] = rep_token_ids

@@ -500,7 +505,11 @@ class PipelineLike():
      new_tokens.append(token)
    return new_tokens

+  def set_control_nets(self, ctrl_nets):
+    self.control_nets = ctrl_nets
+
  # region xformersとか使う部分:独自に書き換えるので関係なし
+
  def enable_xformers_memory_efficient_attention(self):
    r"""
    Enable memory efficient attention as implemented in xformers.

@@ -581,6 +590,8 @@ class PipelineLike():
      latents: Optional[torch.FloatTensor] = None,
      max_embeddings_multiples: Optional[int] = 3,
      output_type: Optional[str] = "pil",
+      vae_batch_size: float = None,
+      return_latents: bool = False,
      # return_dict: bool = True,
      callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
      is_cancelled_callback: Optional[Callable[[], bool]] = None,

@@ -672,6 +683,9 @@ class PipelineLike():
    else:
      raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+    vae_batch_size = batch_size if vae_batch_size is None else (
+        int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
+
    if strength < 0 or strength > 1:
      raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

@@ -752,7 +766,7 @@ class PipelineLike():
      text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
      text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # prompt複数件でもOK

-    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
      if isinstance(clip_guide_images, PIL.Image.Image):
        clip_guide_images = [clip_guide_images]

@@ -765,7 +779,7 @@ class PipelineLike():
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
        if len(image_embeddings_clip) == 1:
          image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-
+      elif self.vgg16_guidance_scale > 0:
        size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # とりあえず1/4に(小さいか?)
        clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
        clip_guide_images = torch.cat(clip_guide_images, dim=0)

@@ -774,6 +788,10 @@ class PipelineLike():
        image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
        if len(image_embeddings_vgg16) == 1:
          image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+      else:
+        # ControlNetのhintにguide imageを流用する
+        # 前処理はControlNet側で行う
+        pass

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps, self.device)

@@ -781,7 +799,6 @@ class PipelineLike():
    latents_dtype = text_embeddings.dtype
    init_latents_orig = None
    mask = None
-    noise = None

    if init_image is None:
      # get the initial random noise unless the user supplied it

@@ -813,6 +830,8 @@ class PipelineLike():
      if isinstance(init_image[0], PIL.Image.Image):
        init_image = [preprocess_image(im) for im in init_image]
        init_image = torch.cat(init_image)
+      if isinstance(init_image, list):
+        init_image = torch.stack(init_image)

      # mask image to tensor
      if mask_image is not None:

@@ -823,9 +842,24 @@ class PipelineLike():

      # encode the init image into latents and scale the latents
      init_image = init_image.to(device=self.device, dtype=latents_dtype)
-
-
-
+      if init_image.size()[2:] == (height // 8, width // 8):
+        init_latents = init_image
+      else:
+        if vae_batch_size >= batch_size:
+          init_latent_dist = self.vae.encode(init_image).latent_dist
+          init_latents = init_latent_dist.sample(generator=generator)
+        else:
+          if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+          init_latents = []
+          for i in tqdm(range(0, batch_size, vae_batch_size)):
+            init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
+                                               if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
+            init_latents.append(init_latent_dist.sample(generator=generator))
+          init_latents = torch.cat(init_latents)
+
+        init_latents = 0.18215 * init_latents
+
      if len(init_latents) == 1:
        init_latents = init_latents.repeat((batch_size, 1, 1, 1))
      init_latents_orig = init_latents

@@ -864,12 +898,21 @@ class PipelineLike():
      extra_step_kwargs["eta"] = eta

    num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
+
+    if self.control_nets:
+      guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+
    for i, t in enumerate(tqdm(timesteps)):
      # expand the latents if we are doing classifier free guidance
      latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
      latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
      # predict the noise residual
-
+      if self.control_nets:
+        noise_pred = original_control_net.call_unet_and_control_net(
+            i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+      else:
+        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

      # perform guidance
      if do_classifier_free_guidance:

@@ -911,8 +954,19 @@ class PipelineLike():
      if is_cancelled_callback is not None and is_cancelled_callback():
        return None

+    if return_latents:
+      return (latents, False)
+
    latents = 1 / 0.18215 * latents
-
+    if vae_batch_size >= batch_size:
+      image = self.vae.decode(latents).sample
+    else:
+      if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+      images = []
+      for i in tqdm(range(0, batch_size, vae_batch_size)):
+        images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
+      image = torch.cat(images)

    image = (image / 2 + 0.5).clamp(0, 1)

@@ -1595,10 +1649,11 @@ def get_unweighted_text_embeddings(
      if pad == eos:  # v1
        text_input_chunk[:, -1] = text_input[0, -1]
      else:  # v2
-
-        text_input_chunk[
-
-        text_input_chunk[
+        for j in range(len(text_input_chunk)):
+          if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # 最後に普通の文字がある
+            text_input_chunk[j, -1] = eos
+          if text_input_chunk[j, 1] == pad:  # BOSだけであとはPAD
+            text_input_chunk[j, 1] = eos

      if clip_skip is None or clip_skip == 1:
        text_embedding = pipe.text_encoder(text_input_chunk)[0]

@@ -1799,7 +1854,7 @@ def preprocess_mask(mask):
  mask = mask.convert("L")
  w, h = mask.size
  w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
+  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # LANCZOS)
  mask = np.array(mask).astype(np.float32) / 255.0
  mask = np.tile(mask, (4, 1, 1))
  mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?

@@ -1817,6 +1872,35 @@ def preprocess_mask(mask):
  # return text_encoder


+class BatchDataBase(NamedTuple):
+  # バッチ分割が必要ないデータ
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # バッチ分割が必要なデータ
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  return_latents: bool
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
def main(args):
  if args.fp16:
    dtype = torch.float16

@@ -1881,10 +1965,7 @@ def main(args):
  # tokenizerを読み込む
  print("loading tokenizer")
  if use_stable_diffusion_format:
-
-    tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-  else:
-    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)

  # schedulerを用意する
  sched_init_args = {}

@@ -1995,11 +2076,13 @@ def main(args):
  # networkを組み込む
  if args.network_module:
    networks = []
+    network_default_muls = []
    for i, network_module in enumerate(args.network_module):
      print("import network module:", network_module)
      imported_module = importlib.import_module(network_module)

      network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)

      net_kwargs = {}
      if args.network_args and i < len(args.network_args):

@@ -2014,7 +2097,7 @@ def main(args):
      network_weight = args.network_weights[i]
      print("load network weights from:", network_weight)

-      if model_util.is_safetensors(network_weight):
+      if model_util.is_safetensors(network_weight) and args.network_show_meta:
        from safetensors.torch import safe_open
        with safe_open(network_weight, framework="pt") as f:
          metadata = f.metadata()

@@ -2037,6 +2120,18 @@ def main(args):
  else:
    networks = []

+  # ControlNetの処理
+  control_nets: List[ControlNetInfo] = []
+  if args.control_net_models:
+    for i, model in enumerate(args.control_net_models):
+      prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+      weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+      ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+      ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+      prep = original_control_net.load_preprocess(prep_type)
+      control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+
  if args.opt_channels_last:
    print(f"set optimizing: channels last")
    text_encoder.to(memory_format=torch.channels_last)

@@ -2050,9 +2145,14 @@ def main(args):
    if vgg16_model is not None:
      vgg16_model.to(memory_format=torch.channels_last)

+    for cn in control_nets:
+      cn.unet.to(memory_format=torch.channels_last)
+      cn.net.to(memory_format=torch.channels_last)
+
  pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                      clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                      vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+  pipe.set_control_nets(control_nets)
  print("pipeline is ready.")

  if args.diffusers_xformers:

@@ -2177,18 +2277,34 @@ def main(args):
    mask_images = l

  # 画像サイズにオプション指定があるときはリサイズする
-  if
-
-
+  if args.W is not None and args.H is not None:
+    if init_images is not None:
+      print(f"resize img2img source images to {args.W}*{args.H}")
+      init_images = resize_images(init_images, (args.W, args.H))
  if mask_images is not None:
    print(f"resize img2img mask images to {args.W}*{args.H}")
    mask_images = resize_images(mask_images, (args.W, args.H))

+  if networks and mask_images:
+    # mask を領域情報として流用する、現在は1枚だけ対応
+    # TODO 複数のnetwork classの混在時の考慮
+    print("use mask as region")
+    # import cv2
+    # for i in range(3):
+    #   cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
+    #   cv2.waitKey()
+    #   cv2.destroyAllWindows()
+    networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+    mask_images = None
+
  prev_image = None  # for VGG16 guided
  if args.guide_image_path is not None:
-    print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-    guide_images =
-
+    print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+    guide_images = []
+    for p in args.guide_image_path:
+      guide_images.extend(load_images(p))
+
+    print(f"loaded {len(guide_images)} guide images for guidance")
    if len(guide_images) == 0:
      print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
      guide_images = None

@@ -2219,33 +2335,46 @@ def main(args):
  iter_seed = random.randint(0, 0x7fffffff)

  # バッチ処理の関数
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
    batch_size = len(batch)

    # highres_fixの処理
    if highres_fix and not highres_1st:
-      # 1st stage
-      print("process 1st
+      # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
+      print("process 1st stage")
      batch_1st = []
-      for
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for _, base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
        width_1st = width_1st - width_1st % 32
        height_1st = height_1st - height_1st % 32
-
+
+        ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                               ext.negative_scale, ext.strength, ext.network_muls)
+        batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
      images_1st = process_batch(batch_1st, True, True)

      # 2nd stageのバッチを作成して以下処理する
-      print("process 2nd
+      print("process 2nd stage")
+      if args.highres_fix_latents_upscaling:
+        org_dtype = images_1st.dtype
+        if images_1st.dtype == torch.bfloat16:
+          images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
+        images_1st = torch.nn.functional.interpolate(
+            images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear')  # , antialias=True)
+        images_1st = images_1st.to(org_dtype)
+
      batch_2nd = []
-      for i, (
-
-
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        if not args.highres_fix_latents_upscaling:
+          image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        batch_2nd.append(bd_2nd)
      batch = batch_2nd

-
-
+    # このバッチの情報を取り出す
+    return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+        (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
    noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

    prompts = []

@@ -2278,7 +2407,7 @@ def main(args):
    all_images_are_same = True
    all_masks_are_same = True
    all_guide_images_are_same = True
-    for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
+    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
      prompts.append(prompt)
      negative_prompts.append(negative_prompt)
      seeds.append(seed)

@@ -2295,9 +2424,13 @@ def main(args):
          all_masks_are_same = mask_images[-2] is mask_image

      if guide_image is not None:
-
-
-        all_guide_images_are_same =
+        if type(guide_image) is list:
+          guide_images.extend(guide_image)
+          all_guide_images_are_same = False
+        else:
+          guide_images.append(guide_image)
+          if i > 0 and all_guide_images_are_same:
+            all_guide_images_are_same = guide_images[-2] is guide_image

      # make start code
      torch.manual_seed(seed)

@@ -2320,10 +2453,24 @@ def main(args):
    if guide_images is not None and all_guide_images_are_same:
      guide_images = guide_images[0]

+    # ControlNet使用時はguide imageをリサイズする
+    if control_nets:
+      # TODO resampleのメソッド
+      guide_images = guide_images if type(guide_images) == list else [guide_images]
+      guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+      if len(guide_images) == 1:
+        guide_images = guide_images[0]
+
    # generate
+    if networks:
+      for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+        n.set_multiplier(m)
+
    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
-                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
-
+                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
+                  vae_batch_size=args.vae_batch_size, return_latents=return_latents,
+                  clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
+    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
      return images

    # save image

@@ -2398,6 +2545,7 @@ def main(args):
      strength = 0.8 if args.strength is None else args.strength
      negative_prompt = ""
      clip_prompt = None
+      network_muls = None

      prompt_args = prompt.strip().split(' --')
      prompt = prompt_args[0]

@@ -2461,6 +2609,15 @@ def main(args):
            clip_prompt = m.group(1)
            print(f"clip prompt: {clip_prompt}")
            continue
+
+          m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+          if m:  # network multiplies
+            network_muls = [float(v) for v in m.group(1).split(",")]
+            while len(network_muls) < len(networks):
+              network_muls.append(network_muls[-1])
+            print(f"network mul: {network_muls}")
+            continue
+
        except ValueError as ex:
          print(f"Exception in parsing / 解析エラー: {parg}")
          print(ex)

@@ -2498,7 +2655,12 @@ def main(args):
        mask_image = mask_images[global_step % len(mask_images)]

      if guide_images is not None:
-
+        if control_nets:  # 複数件の場合あり
+          c = len(control_nets)
+          p = global_step % (len(guide_images) // c)
+          guide_image = guide_images[p * c:p * c + c]
+        else:
+          guide_image = guide_images[global_step % len(guide_images)]
      elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
        if prev_image is None:
          print("Generate 1st image without guide image.")

@@ -2506,10 +2668,9 @@ def main(args):
          print("Use previous image as guide image.")
          guide_image = prev_image

-
-
-
-      if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
+      b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                     BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
+      if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
        process_batch(batch_data, highres_fix)
        batch_data.clear()

@@ -2553,6 +2714,8 @@ if __name__ == '__main__':
  parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
  parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
+  parser.add_argument("--vae_batch_size", type=float, default=None,
+                      help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
  parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
  parser.add_argument('--sampler', type=str, default='ddim',
                      choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',

@@ -2564,6 +2727,8 @@ if __name__ == '__main__':
  parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
  parser.add_argument("--vae", type=str, default=None,
                      help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
  # parser.add_argument("--replace_clip_l14_336", action='store_true',
  #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
  parser.add_argument("--seed", type=int, default=None,

@@ -2578,12 +2743,15 @@ if __name__ == '__main__':
  parser.add_argument("--opt_channels_last", action='store_true',
                      help='set channels last option to model / モデルにchannels lastを指定し最適化する')
  parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
  parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
                      help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
  parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                      help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
  parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

@@ -2597,15 +2765,26 @@ if __name__ == '__main__':
                      help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
  parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                      help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-  parser.add_argument("--guide_image_path", type=str, default=None,
+  parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                      help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
  parser.add_argument("--highres_fix_scale", type=float, default=None,
                      help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
  parser.add_argument("--highres_fix_steps", type=int, default=28,
                      help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
  parser.add_argument("--highres_fix_save_1st", action='store_true',
                      help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
+                      help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
  parser.add_argument("--negative_scale", type=float, default=None,
                      help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")

+  parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                      help='ControlNet models to use / 使用するControlNetのモデル名')
+  parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                      help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+  parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+  parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                      help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
+
  args = parser.parse_args()
  main(args)
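Note on vae_batch_size: the value is interpreted either as an absolute VAE batch size (1 or more) or as a fraction of the generation batch size (below 1), exactly as in the expression added to PipelineLike.__call__ above. A minimal standalone sketch of that rule (the helper name is ours, not part of the diff):

from typing import Optional

def resolve_vae_batch_size(batch_size: int, vae_batch_size: Optional[float]) -> int:
  # None -> reuse the generation batch size; >= 1 -> absolute size; < 1 -> ratio of batch_size
  if vae_batch_size is None:
    return batch_size
  if vae_batch_size >= 1:
    return int(vae_batch_size)
  return max(1, int(batch_size * vae_batch_size))

# e.g. resolve_vae_batch_size(8, None) == 8; resolve_vae_batch_size(8, 2) == 2; resolve_vae_batch_size(8, 0.25) == 2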
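Note on batch grouping: prompts are now queued as BatchData entries, and the queue is flushed whenever the incoming entry's BatchDataExt part (size, steps, scales, strength, network multipliers) differs from the previous one, because those settings cannot vary inside a single batch. A small illustration of that rule with a stand-in NamedTuple (not the full BatchDataExt from the diff):

from typing import List, NamedTuple

class Ext(NamedTuple):
  # stand-in for BatchDataExt: the fields that force a batch split
  width: int
  height: int
  steps: int

def needs_flush(queued: List[Ext], incoming: Ext) -> bool:
  # same rule as `batch_data[-1].ext != b1.ext`; NamedTuple equality is field-wise
  return len(queued) > 0 and queued[-1] != incoming

print(needs_flush([Ext(512, 512, 30)], Ext(512, 512, 30)))  # False -> keep batching
print(needs_flush([Ext(512, 512, 30)], Ext(512, 768, 30)))  # True  -> process the current batch first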
library/model_util.py
CHANGED
@@ -4,7 +4,7 @@
import math
import os
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file

@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
  else:
    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+    logging.set_verbosity_error()  # don't show annoying warning
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    logging.set_verbosity_warning()
+
    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
  print("loading text encoder:", info)
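The verbosity is lowered only around CLIPTextModel.from_pretrained and restored immediately afterwards. If the pattern is reused elsewhere, wrapping it in try/finally keeps the level from being stuck at error when loading raises; this variant is ours, not part of the diff:

from transformers import CLIPTextModel, logging

def load_clip_text_model_quietly(name: str = "openai/clip-vit-large-patch14") -> CLIPTextModel:
  # Hide the "some weights were not initialized" warning only while loading, then restore verbosity.
  logging.set_verbosity_error()
  try:
    return CLIPTextModel.from_pretrained(name)
  finally:
    logging.set_verbosity_warning()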
library/train_util.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
|
|
4 |
import json
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
-
from torch.autograd.function import Function
|
10 |
import glob
|
11 |
import math
|
12 |
import os
|
@@ -17,10 +26,16 @@ from io import BytesIO
|
|
17 |
|
18 |
from tqdm import tqdm
|
19 |
import torch
|
|
|
20 |
from torchvision import transforms
|
21 |
from transformers import CLIPTokenizer
|
|
|
22 |
import diffusers
|
23 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
24 |
import albumentations as albu
|
25 |
import numpy as np
|
26 |
from PIL import Image
|
@@ -195,23 +210,95 @@
   batch_index: int


 class BaseDataset(torch.utils.data.Dataset):
-  def __init__(self, tokenizer, max_token_length
     super().__init__()
-    self.tokenizer
     self.max_token_length = max_token_length
-    self.shuffle_caption = shuffle_caption
-    self.shuffle_keep_tokens = shuffle_keep_tokens
     # width/height is used when enable_bucket==False
     self.width, self.height = (None, None) if resolution is None else resolution
-    self.face_crop_aug_range = face_crop_aug_range
-    self.flip_aug = flip_aug
-    self.color_aug = color_aug
     self.debug_dataset = debug_dataset
-
     self.token_padding_disabled = False
-    self.dataset_dirs_info = {}
-    self.reg_dataset_dirs_info = {}
     self.tag_frequency = {}

     self.enable_bucket = False
@@ -225,49 +312,28 @@ class BaseDataset(torch.utils.data.Dataset):
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

     self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
-    self.dropout_rate: float = 0
-    self.dropout_every_n_epochs: int = None
-    self.tag_dropout_rate: float = 0

     # augmentation
-
-    if color_aug:
-      # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
-      self.aug = albu.Compose([
-        albu.OneOf([
-          albu.HueSaturationValue(8, 0, 0, p=.5),
-          albu.RandomGamma((95, 105), p=.5),
-        ], p=.33),
-        albu.HorizontalFlip(p=flip_p)
-      ], p=1.)
-    elif flip_aug:
-      self.aug = albu.Compose([
-        albu.HorizontalFlip(p=flip_p)
-      ], p=1.)
-    else:
-      self.aug = None

     self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])

     self.image_data: Dict[str, ImageInfo] = {}

     self.replacements = {}

   def set_current_epoch(self, epoch):
     self.current_epoch = epoch
-
-  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
-    # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
-    self.dropout_rate = dropout_rate
-    self.dropout_every_n_epochs = dropout_every_n_epochs
-    self.tag_dropout_rate = tag_dropout_rate

   def set_tag_frequency(self, dir_name, captions):
     frequency_for_dir = self.tag_frequency.get(dir_name, {})
     self.tag_frequency[dir_name] = frequency_for_dir
     for caption in captions:
       for tag in caption.split(","):
-
         tag = tag.lower()
         frequency = frequency_for_dir.get(tag, 0)
         frequency_for_dir[tag] = frequency + 1
@@ -278,42 +344,36 @@ class BaseDataset(torch.utils.data.Dataset):
   def add_replacement(self, str_from, str_to):
     self.replacements[str_from] = str_to

-  def process_caption(self, caption):
     # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
-    is_drop_out =
-    is_drop_out = is_drop_out or

     if is_drop_out:
       caption = ""
     else:
-      if
         def dropout_tags(tokens):
-          if
             return tokens
           l = []
           for token in tokens:
-            if random.random() >=
               l.append(token)
           return l

-
-
-
-
-
-        tokens = dropout_tags(tokens)
-      else:
-        if len(tokens) > self.shuffle_keep_tokens:
-          keep_tokens = tokens[:self.shuffle_keep_tokens]
-          tokens = tokens[self.shuffle_keep_tokens:]

-
-

-

-
-      caption = ", ".join(tokens)

     # textual inversion対応
     for str_from, str_to in self.replacements.items():
@@ -367,8 +427,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
367 |
input_ids = torch.stack(iids_list) # 3,77
|
368 |
return input_ids
|
369 |
|
370 |
-
def register_image(self, info: ImageInfo):
|
371 |
self.image_data[info.image_key] = info
|
|
|
372 |
|
373 |
def make_buckets(self):
|
374 |
'''
|
@@ -467,7 +528,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
467 |
img = np.array(image, np.uint8)
|
468 |
return img
|
469 |
|
470 |
-
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
image_height, image_width = image.shape[0:2]
|
472 |
|
473 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
@@ -477,22 +538,27 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
477 |
image_height, image_width = image.shape[0:2]
|
478 |
if image_width > reso[0]:
|
479 |
trim_size = image_width - reso[0]
|
480 |
-
p = trim_size // 2 if not
|
481 |
# print("w", trim_size, p)
|
482 |
image = image[:, p:p + reso[0]]
|
483 |
if image_height > reso[1]:
|
484 |
trim_size = image_height - reso[1]
|
485 |
-
p = trim_size // 2 if not
|
486 |
# print("h", trim_size, p)
|
487 |
image = image[p:p + reso[1]]
|
488 |
|
489 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
return image
|
491 |
|
|
|
|
|
|
|
492 |
def cache_latents(self, vae):
|
493 |
# TODO ここを高速化したい
|
494 |
print("caching latents.")
|
495 |
for info in tqdm(self.image_data.values()):
|
|
|
|
|
496 |
if info.latents_npz is not None:
|
497 |
info.latents = self.load_latents_from_npz(info, False)
|
498 |
info.latents = torch.FloatTensor(info.latents)
|
@@ -502,13 +568,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
502 |
continue
|
503 |
|
504 |
image = self.load_image(info.absolute_path)
|
505 |
-
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
|
507 |
img_tensor = self.image_transforms(image)
|
508 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
|
511 |
-
if
|
512 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
img_tensor = self.image_transforms(image)
|
514 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
@@ -518,11 +584,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
518 |
image = Image.open(image_path)
|
519 |
return image.size
|
520 |
|
521 |
-
def load_image_with_face_info(self, image_path: str):
|
522 |
img = self.load_image(image_path)
|
523 |
|
524 |
face_cx = face_cy = face_w = face_h = 0
|
525 |
-
if
|
526 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
if len(tokens) >= 5:
|
528 |
face_cx = int(tokens[-4])
|
@@ -533,7 +599,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
533 |
return img, face_cx, face_cy, face_w, face_h
|
534 |
|
535 |
# いい感じに切り出す
|
536 |
-
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
height, width = image.shape[0:2]
|
538 |
if height == self.height and width == self.width:
|
539 |
return image
|
@@ -541,8 +607,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
541 |
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
face_size = max(face_w, face_h)
|
543 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
-
min_scale = min(1.0, max(min_scale, self.size / (face_size *
|
545 |
-
max_scale = min(1.0, max(min_scale, self.size / (face_size *
|
546 |
if min_scale >= max_scale: # range指定がmin==max
|
547 |
scale = min_scale
|
548 |
else:
|
@@ -560,13 +626,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
560 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
|
563 |
-
if
|
564 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
else:
|
568 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
-
if
|
570 |
if face_size > self.size // 10 and face_size >= 40:
|
571 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
|
@@ -589,9 +655,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
589 |
return self._length
|
590 |
|
591 |
def __getitem__(self, index):
|
592 |
-
if index == 0:
|
593 |
-
self.shuffle_buckets()
|
594 |
-
|
595 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
@@ -604,28 +667,29 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
604 |
|
605 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
image_info = self.image_data[image_key]
|
|
|
607 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
|
609 |
# image/latentsを処理する
|
610 |
if image_info.latents is not None:
|
611 |
-
latents = image_info.latents if not
|
612 |
image = None
|
613 |
elif image_info.latents_npz is not None:
|
614 |
-
latents = self.load_latents_from_npz(image_info,
|
615 |
latents = torch.FloatTensor(latents)
|
616 |
image = None
|
617 |
else:
|
618 |
# 画像を読み込み、必要ならcropする
|
619 |
-
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
im_h, im_w = img.shape[0:2]
|
621 |
|
622 |
if self.enable_bucket:
|
623 |
-
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
else:
|
625 |
if face_cx > 0: # 顔位置情報あり
|
626 |
-
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
elif im_h > self.height or im_w > self.width:
|
628 |
-
assert
|
629 |
if im_h > self.height:
|
630 |
p = random.randint(0, im_h - self.height)
|
631 |
img = img[p:p + self.height]
|
@@ -637,8 +701,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
637 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
|
639 |
# augmentation
|
640 |
-
|
641 |
-
|
|
|
642 |
|
643 |
latents = None
|
644 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
@@ -646,7 +711,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
646 |
images.append(image)
|
647 |
latents_list.append(latents)
|
648 |
|
649 |
-
caption = self.process_caption(image_info.caption)
|
650 |
captions.append(caption)
|
651 |
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
input_ids_list.append(self.get_input_ids(caption))
|
@@ -677,9 +742,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
677 |
|
678 |
|
679 |
class DreamBoothDataset(BaseDataset):
|
680 |
-
def __init__(self,
|
681 |
-
super().__init__(tokenizer, max_token_length,
|
682 |
-
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
|
684 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
|
@@ -702,7 +766,7 @@ class DreamBoothDataset(BaseDataset):
|
|
702 |
self.bucket_reso_steps = None # この情報は使われない
|
703 |
self.bucket_no_upscale = False
|
704 |
|
705 |
-
def read_caption(img_path):
|
706 |
# captionの候補ファイル名を作る
|
707 |
base_name = os.path.splitext(img_path)[0]
|
708 |
base_name_face_det = base_name
|
@@ -725,153 +789,181 @@ class DreamBoothDataset(BaseDataset):
|
|
725 |
break
|
726 |
return caption
|
727 |
|
728 |
-
def load_dreambooth_dir(
|
729 |
-
if not os.path.isdir(
|
730 |
-
|
731 |
-
return
|
732 |
|
733 |
-
|
734 |
-
|
735 |
-
n_repeats = int(tokens[0])
|
736 |
-
except ValueError as e:
|
737 |
-
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
-
return 0, [], []
|
739 |
-
|
740 |
-
caption_by_folder = '_'.join(tokens[1:])
|
741 |
-
img_paths = glob_images(dir, "*")
|
742 |
-
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
|
744 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
captions = []
|
746 |
for img_path in img_paths:
|
747 |
-
cap_for_img = read_caption(img_path)
|
748 |
-
|
|
|
|
|
|
|
|
|
749 |
|
750 |
-
self.set_tag_frequency(os.path.basename(
|
751 |
|
752 |
-
return
|
753 |
|
754 |
-
print("prepare
|
755 |
-
train_dirs = os.listdir(train_data_dir)
|
756 |
num_train_images = 0
|
757 |
-
|
758 |
-
|
759 |
-
|
|
|
|
|
|
|
|
760 |
|
761 |
for img_path, caption in zip(img_paths, captions):
|
762 |
-
info = ImageInfo(img_path,
|
763 |
-
|
|
|
|
|
|
|
764 |
|
765 |
-
|
|
|
766 |
|
767 |
print(f"{num_train_images} train images with repeating.")
|
768 |
self.num_train_images = num_train_images
|
769 |
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
print("prepare reg images.")
|
774 |
-
reg_infos: List[ImageInfo] = []
|
775 |
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
780 |
|
781 |
-
|
782 |
-
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
-
reg_infos.append(info)
|
784 |
|
785 |
-
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
|
787 |
-
|
788 |
-
|
789 |
-
|
|
|
|
|
|
|
|
|
|
|
790 |
|
791 |
-
|
792 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
else:
|
794 |
-
|
795 |
-
n = 0
|
796 |
-
first_loop = True
|
797 |
-
while n < num_train_images:
|
798 |
-
for info in reg_infos:
|
799 |
-
if first_loop:
|
800 |
-
self.register_image(info)
|
801 |
-
n += info.num_repeats
|
802 |
-
else:
|
803 |
-
info.num_repeats += 1
|
804 |
-
n += 1
|
805 |
-
if n >= num_train_images:
|
806 |
-
break
|
807 |
-
first_loop = False
|
808 |
|
809 |
-
|
|
|
|
|
810 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
print(f"loading existing metadata: {json_file_name}")
|
820 |
-
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
-
metadata = json.load(f)
|
822 |
-
else:
|
823 |
-
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
|
825 |
-
|
826 |
-
|
827 |
-
self.batch_size = batch_size
|
828 |
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
-
abs_path = abs_path[0]
|
839 |
-
|
840 |
-
caption = img_md.get('caption')
|
841 |
-
tags = img_md.get('tags')
|
842 |
-
if caption is None:
|
843 |
-
caption = tags
|
844 |
-
elif tags is not None and len(tags) > 0:
|
845 |
-
caption = caption + ', ' + tags
|
846 |
-
tags_list.append(tags)
|
847 |
-
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
-
|
849 |
-
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
-
image_info.image_size = img_md.get('train_resolution')
|
851 |
-
|
852 |
-
if not self.color_aug and not self.random_crop:
|
853 |
-
# if npz exists, use them
|
854 |
-
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
-
|
856 |
-
self.register_image(image_info)
|
857 |
-
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
-
self.num_reg_images = 0
|
859 |
|
860 |
-
|
861 |
-
|
862 |
-
|
|
|
|
|
|
|
863 |
|
864 |
# check existence of all npz files
|
865 |
-
use_npz_latents = not (
|
866 |
if use_npz_latents:
|
|
|
867 |
npz_any = False
|
868 |
npz_all = True
|
|
|
869 |
for image_info in self.image_data.values():
|
|
|
|
|
870 |
has_npz = image_info.latents_npz is not None
|
871 |
npz_any = npz_any or has_npz
|
872 |
|
873 |
-
if
|
874 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
|
|
875 |
npz_all = npz_all and has_npz
|
876 |
|
877 |
if npz_any and not npz_all:
|
@@ -883,7 +975,7 @@ class FineTuningDataset(BaseDataset):
|
|
883 |
elif not npz_all:
|
884 |
use_npz_latents = False
|
885 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
-
if
|
887 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
888 |
# else:
|
889 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
@@ -929,7 +1021,7 @@ class FineTuningDataset(BaseDataset):
|
|
929 |
for image_info in self.image_data.values():
|
930 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
|
932 |
-
def image_key_to_npz_file(self, image_key):
|
933 |
base_name = os.path.splitext(image_key)[0]
|
934 |
npz_file_norm = base_name + '.npz'
|
935 |
|
@@ -941,8 +1033,8 @@ class FineTuningDataset(BaseDataset):
|
|
941 |
return npz_file_norm, npz_file_flip
|
942 |
|
943 |
# image_key is relative path
|
944 |
-
npz_file_norm = os.path.join(
|
945 |
-
npz_file_flip = os.path.join(
|
946 |
|
947 |
if not os.path.exists(npz_file_norm):
|
948 |
npz_file_norm = None
|
@@ -953,13 +1045,60 @@ class FineTuningDataset(BaseDataset):
|
|
953 |
return npz_file_norm, npz_file_flip
|
954 |
|
955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
|
960 |
train_dataset.set_current_epoch(1)
|
961 |
k = 0
|
962 |
-
|
|
|
|
|
|
|
963 |
if example['latents'] is not None:
|
964 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
@@ -1364,6 +1503,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|
1364 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
|
1369 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
@@ -1387,10 +1555,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1387 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
-
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
-
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
-
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
-
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
parser.add_argument("--xformers", action="store_true",
|
@@ -1398,7 +1562,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1398 |
parser.add_argument("--vae", type=str, default=None,
|
1399 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
|
1401 |
-
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
@@ -1419,15 +1582,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1419 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
-
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
-
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
-
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
-
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
parser.add_argument("--lowram", action="store_true",
|
1429 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1431 |
if support_dreambooth:
|
1432 |
# DreamBooth training
|
1433 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
@@ -1449,8 +1620,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1449 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
-
parser.add_argument("--keep_tokens", type=int, default=
|
1453 |
-
help="keep heading N tokens when shuffling caption tokens / caption
|
1454 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
@@ -1475,11 +1646,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1475 |
if support_caption_dropout:
|
1476 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
-
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
-
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
|
1481 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
-
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
|
1485 |
if support_dreambooth:
|
@@ -1504,16 +1675,256 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|
1504 |
# region utils
|
1505 |
|
1506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1507 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
# backward compatibility
|
1509 |
if args.caption_extention is not None:
|
1510 |
args.caption_extension = args.caption_extention
|
1511 |
args.caption_extention = None
|
1512 |
|
1513 |
-
if args.cache_latents:
|
1514 |
-
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
-
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
-
|
1517 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
if args.resolution is not None:
|
1519 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
@@ -1536,12 +1947,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|
1536 |
|
1537 |
def load_tokenizer(args: argparse.Namespace):
|
1538 |
print("prepare tokenizer")
|
1539 |
-
if args.v2
|
1540 |
-
|
1541 |
-
|
1542 |
-
|
1543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1544 |
print(f"update token length: {args.max_token_length}")
|
|
|
|
|
|
|
|
|
|
|
1545 |
return tokenizer
|
1546 |
|
1547 |
|
@@ -1592,13 +2019,19 @@ def prepare_dtype(args: argparse.Namespace):
|
|
1592 |
|
1593 |
|
1594 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
-
|
|
|
|
|
1596 |
if load_stable_diffusion_format:
|
1597 |
print("load StableDiffusion checkpoint")
|
1598 |
-
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
|
1599 |
else:
|
1600 |
print("load Diffusers pretrained models")
|
1601 |
-
|
|
|
|
|
|
|
|
|
1602 |
text_encoder = pipe.text_encoder
|
1603 |
vae = pipe.vae
|
1604 |
unet = pipe.unet
|
@@ -1767,6 +2200,197 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
1767 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
|
|
|
|
|
|
|
|
|
|
|
|
1770 |
# endregion
|
1771 |
|
1772 |
# region 前処理用
|
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
4 |
+
import importlib
|
5 |
import json
|
6 |
+
import re
|
7 |
import shutil
|
8 |
import time
|
9 |
+
from typing import (
|
10 |
+
Dict,
|
11 |
+
List,
|
12 |
+
NamedTuple,
|
13 |
+
Optional,
|
14 |
+
Sequence,
|
15 |
+
Tuple,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
from accelerate import Accelerator
|
|
|
19 |
import glob
|
20 |
import math
|
21 |
import os
|
|
|
26 |
|
27 |
from tqdm import tqdm
|
28 |
import torch
|
29 |
+
from torch.optim import Optimizer
|
30 |
from torchvision import transforms
|
31 |
from transformers import CLIPTokenizer
|
32 |
+
import transformers
|
33 |
import diffusers
|
34 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
35 |
+
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
36 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
37 |
+
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
38 |
+
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
39 |
import albumentations as albu
|
40 |
import numpy as np
|
41 |
from PIL import Image
|
|
|
210 |
batch_index: int
|
211 |
|
212 |
|
213 |
+
class AugHelper:
|
214 |
+
def __init__(self):
|
215 |
+
# prepare all possible augmentators
|
216 |
+
color_aug_method = albu.OneOf([
|
217 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
218 |
+
albu.RandomGamma((95, 105), p=.5),
|
219 |
+
], p=.33)
|
220 |
+
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
221 |
+
|
222 |
+
# key: (use_color_aug, use_flip_aug)
|
223 |
+
self.augmentors = {
|
224 |
+
(True, True): albu.Compose([
|
225 |
+
color_aug_method,
|
226 |
+
flip_aug_method,
|
227 |
+
], p=1.),
|
228 |
+
(True, False): albu.Compose([
|
229 |
+
color_aug_method,
|
230 |
+
], p=1.),
|
231 |
+
(False, True): albu.Compose([
|
232 |
+
flip_aug_method,
|
233 |
+
], p=1.),
|
234 |
+
(False, False): None
|
235 |
+
}
|
236 |
+
|
237 |
+
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
238 |
+
return self.augmentors[(use_color_aug, use_flip_aug)]
|
239 |
+
|
240 |
+
|
241 |
+
class BaseSubset:
|
242 |
+
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
243 |
+
self.image_dir = image_dir
|
244 |
+
self.num_repeats = num_repeats
|
245 |
+
self.shuffle_caption = shuffle_caption
|
246 |
+
self.keep_tokens = keep_tokens
|
247 |
+
self.color_aug = color_aug
|
248 |
+
self.flip_aug = flip_aug
|
249 |
+
self.face_crop_aug_range = face_crop_aug_range
|
250 |
+
self.random_crop = random_crop
|
251 |
+
self.caption_dropout_rate = caption_dropout_rate
|
252 |
+
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
253 |
+
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
254 |
+
|
255 |
+
self.img_count = 0
|
256 |
+
|
257 |
+
|
258 |
+
class DreamBoothSubset(BaseSubset):
|
259 |
+
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
260 |
+
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
261 |
+
|
262 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
263 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
264 |
+
|
265 |
+
self.is_reg = is_reg
|
266 |
+
self.class_tokens = class_tokens
|
267 |
+
self.caption_extension = caption_extension
|
268 |
+
|
269 |
+
def __eq__(self, other) -> bool:
|
270 |
+
if not isinstance(other, DreamBoothSubset):
|
271 |
+
return NotImplemented
|
272 |
+
return self.image_dir == other.image_dir
|
273 |
+
|
274 |
+
|
275 |
+
class FineTuningSubset(BaseSubset):
|
276 |
+
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
277 |
+
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
278 |
+
|
279 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
280 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
281 |
+
|
282 |
+
self.metadata_file = metadata_file
|
283 |
+
|
284 |
+
def __eq__(self, other) -> bool:
|
285 |
+
if not isinstance(other, FineTuningSubset):
|
286 |
+
return NotImplemented
|
287 |
+
return self.metadata_file == other.metadata_file
|
288 |
+
|
289 |
+
|
290 |
class BaseDataset(torch.utils.data.Dataset):
|
291 |
+
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
292 |
super().__init__()
|
293 |
+
self.tokenizer = tokenizer
|
294 |
self.max_token_length = max_token_length
|
|
|
|
|
295 |
# width/height is used when enable_bucket==False
|
296 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
297 |
self.debug_dataset = debug_dataset
|
298 |
+
|
299 |
+
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
300 |
+
|
301 |
self.token_padding_disabled = False
|
|
|
|
|
302 |
self.tag_frequency = {}
|
303 |
|
304 |
self.enable_bucket = False
|
|
|
312 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
313 |
|
314 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
315 |
|
316 |
# augmentation
|
317 |
+
self.aug_helper = AugHelper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
320 |
|
321 |
self.image_data: Dict[str, ImageInfo] = {}
|
322 |
+
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
323 |
|
324 |
self.replacements = {}
|
325 |
|
326 |
def set_current_epoch(self, epoch):
|
327 |
self.current_epoch = epoch
|
328 |
+
self.shuffle_buckets()
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
def set_tag_frequency(self, dir_name, captions):
|
331 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
332 |
self.tag_frequency[dir_name] = frequency_for_dir
|
333 |
for caption in captions:
|
334 |
for tag in caption.split(","):
|
335 |
+
tag = tag.strip()
|
336 |
+
if tag:
|
337 |
tag = tag.lower()
|
338 |
frequency = frequency_for_dir.get(tag, 0)
|
339 |
frequency_for_dir[tag] = frequency + 1
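For orientation, the per-directory options (repeats, caption shuffling, augmentation, dropout) now live on subset objects rather than on the dataset itself. A hypothetical construction of two DreamBooth subsets follows; the directories and class tokens are placeholders, and the keyword names are taken from the __init__ signatures in the hunk above.

from library.train_util import DreamBoothSubset  # module path as used in this repository

# One subset for training images, one for regularization images (placeholder values).
train_subset = DreamBoothSubset(
    image_dir="train_data/10_mychar", is_reg=False, class_tokens="mychar girl",
    caption_extension=".caption", num_repeats=10, shuffle_caption=True, keep_tokens=1,
    color_aug=False, flip_aug=True, face_crop_aug_range=None, random_crop=False,
    caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.1)

reg_subset = DreamBoothSubset(
    image_dir="reg_data/1_girl", is_reg=True, class_tokens="girl",
    caption_extension=".caption", num_repeats=1, shuffle_caption=False, keep_tokens=0,
    color_aug=False, flip_aug=False, face_crop_aug_range=None, random_crop=False,
    caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.0)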
|
|
|
344 |
def add_replacement(self, str_from, str_to):
|
345 |
self.replacements[str_from] = str_to
|
346 |
|
347 |
+
def process_caption(self, subset: BaseSubset, caption):
|
348 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
349 |
+
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
350 |
+
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
351 |
|
352 |
if is_drop_out:
|
353 |
caption = ""
|
354 |
else:
|
355 |
+
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
356 |
def dropout_tags(tokens):
|
357 |
+
if subset.caption_tag_dropout_rate <= 0:
|
358 |
return tokens
|
359 |
l = []
|
360 |
for token in tokens:
|
361 |
+
if random.random() >= subset.caption_tag_dropout_rate:
|
362 |
l.append(token)
|
363 |
return l
|
364 |
|
365 |
+
fixed_tokens = []
|
366 |
+
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
367 |
+
if subset.keep_tokens > 0:
|
368 |
+
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
369 |
+
flex_tokens = flex_tokens[subset.keep_tokens:]
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
+
if subset.shuffle_caption:
|
372 |
+
random.shuffle(flex_tokens)
|
373 |
|
374 |
+
flex_tokens = dropout_tags(flex_tokens)
|
375 |
|
376 |
+
caption = ", ".join(fixed_tokens + flex_tokens)
|
|
|
377 |
|
378 |
# textual inversion対応
|
379 |
for str_from, str_to in self.replacements.items():
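The caption handling above keeps the first keep_tokens comma separated tags fixed, optionally shuffles the rest, and drops individual tags with probability caption_tag_dropout_rate. A standalone sketch of that behaviour (simplified, not the exact code path):

import random

def shuffle_and_drop(caption, keep_tokens=1, shuffle=True, tag_dropout_rate=0.1):
    # Mirror of the per-subset caption logic: fixed head, shuffled and thinned tail.
    tokens = [t.strip() for t in caption.strip().split(",")]
    fixed, flex = tokens[:keep_tokens], tokens[keep_tokens:]
    if shuffle:
        random.shuffle(flex)
    flex = [t for t in flex if random.random() >= tag_dropout_rate]
    return ", ".join(fixed + flex)

print(shuffle_and_drop("mychar, 1girl, smile, outdoors"))
# e.g. "mychar, outdoors, 1girl" (order and dropped tags vary per call)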
|
|
|
427 |
input_ids = torch.stack(iids_list) # 3,77
|
428 |
return input_ids
|
429 |
|
430 |
+
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
431 |
self.image_data[info.image_key] = info
|
432 |
+
self.image_to_subset[info.image_key] = subset
|
433 |
|
434 |
def make_buckets(self):
|
435 |
'''
|
|
|
528 |
img = np.array(image, np.uint8)
|
529 |
return img
|
530 |
|
531 |
+
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
532 |
image_height, image_width = image.shape[0:2]
|
533 |
|
534 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
538 |
image_height, image_width = image.shape[0:2]
|
539 |
if image_width > reso[0]:
|
540 |
trim_size = image_width - reso[0]
|
541 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
542 |
# print("w", trim_size, p)
|
543 |
image = image[:, p:p + reso[0]]
|
544 |
if image_height > reso[1]:
|
545 |
trim_size = image_height - reso[1]
|
546 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
547 |
# print("h", trim_size, p)
|
548 |
image = image[p:p + reso[1]]
|
549 |
|
550 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
551 |
return image
|
552 |
|
553 |
+
def is_latent_cacheable(self):
|
554 |
+
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
555 |
+
|
556 |
def cache_latents(self, vae):
|
557 |
# TODO ここを高速化したい
|
558 |
print("caching latents.")
|
559 |
for info in tqdm(self.image_data.values()):
|
560 |
+
subset = self.image_to_subset[info.image_key]
|
561 |
+
|
562 |
if info.latents_npz is not None:
|
563 |
info.latents = self.load_latents_from_npz(info, False)
|
564 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
568 |
continue
|
569 |
|
570 |
image = self.load_image(info.absolute_path)
|
571 |
+
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
572 |
|
573 |
img_tensor = self.image_transforms(image)
|
574 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
575 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
576 |
|
577 |
+
if subset.flip_aug:
|
578 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
579 |
img_tensor = self.image_transforms(image)
|
580 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
584 |
image = Image.open(image_path)
|
585 |
return image.size
|
586 |
|
587 |
+
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
588 |
img = self.load_image(image_path)
|
589 |
|
590 |
face_cx = face_cy = face_w = face_h = 0
|
591 |
+
if subset.face_crop_aug_range is not None:
|
592 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
593 |
if len(tokens) >= 5:
|
594 |
face_cx = int(tokens[-4])
|
|
|
599 |
return img, face_cx, face_cy, face_w, face_h
|
600 |
|
601 |
# いい感じに切り出す
|
602 |
+
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
|
603 |
height, width = image.shape[0:2]
|
604 |
if height == self.height and width == self.width:
|
605 |
return image
|
|
|
607 |
# 画像サイズはsizeより大きいのでリサイズする
|
608 |
face_size = max(face_w, face_h)
|
609 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
610 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
611 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
612 |
if min_scale >= max_scale: # range指定がmin==max
|
613 |
scale = min_scale
|
614 |
else:
|
|
|
626 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
627 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
628 |
|
629 |
+
if subset.random_crop:
|
630 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
631 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
632 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
633 |
else:
|
634 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
635 |
+
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
636 |
if face_size > self.size // 10 and face_size >= 40:
|
637 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
638 |
|
|
|
655 |
return self._length
|
656 |
|
657 |
def __getitem__(self, index):
|
|
|
|
|
|
|
658 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
659 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
660 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
667 |
|
668 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
669 |
image_info = self.image_data[image_key]
|
670 |
+
subset = self.image_to_subset[image_key]
|
671 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
672 |
|
673 |
# image/latentsを処理する
|
674 |
if image_info.latents is not None:
|
675 |
+
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
|
676 |
image = None
|
677 |
elif image_info.latents_npz is not None:
|
678 |
+
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
|
679 |
latents = torch.FloatTensor(latents)
|
680 |
image = None
|
681 |
else:
|
682 |
# 画像を読み込み、必要ならcropする
|
683 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
684 |
im_h, im_w = img.shape[0:2]
|
685 |
|
686 |
if self.enable_bucket:
|
687 |
+
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
688 |
else:
|
689 |
if face_cx > 0: # 顔位置情報あり
|
690 |
+
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
691 |
elif im_h > self.height or im_w > self.width:
|
692 |
+
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
693 |
if im_h > self.height:
|
694 |
p = random.randint(0, im_h - self.height)
|
695 |
img = img[p:p + self.height]
|
|
|
701 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
702 |
|
703 |
# augmentation
|
704 |
+
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
705 |
+
if aug is not None:
|
706 |
+
img = aug(image=img)['image']
|
707 |
|
708 |
latents = None
|
709 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
711 |
images.append(image)
|
712 |
latents_list.append(latents)
|
713 |
|
714 |
+
caption = self.process_caption(subset, image_info.caption)
|
715 |
captions.append(caption)
|
716 |
if not self.token_padding_disabled: # this option might be omitted in future
|
717 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
742 |
|
743 |
|
744 |
class DreamBoothDataset(BaseDataset):
|
745 |
+
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
|
746 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
|
|
747 |
|
748 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
749 |
|
|
|
766 |
self.bucket_reso_steps = None # この情報は使われない
|
767 |
self.bucket_no_upscale = False
|
768 |
|
769 |
+
def read_caption(img_path, caption_extension):
|
770 |
# captionの候補ファイル名を作る
|
771 |
base_name = os.path.splitext(img_path)[0]
|
772 |
base_name_face_det = base_name
|
|
|
789 |
break
|
790 |
return caption
|
791 |
|
792 |
+
def load_dreambooth_dir(subset: DreamBoothSubset):
|
793 |
+
if not os.path.isdir(subset.image_dir):
|
794 |
+
print(f"not directory: {subset.image_dir}")
|
795 |
+
return [], []
|
796 |
|
797 |
+
img_paths = glob_images(subset.image_dir, "*")
|
798 |
+
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
|
800 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
801 |
captions = []
|
802 |
for img_path in img_paths:
|
803 |
+
cap_for_img = read_caption(img_path, subset.caption_extension)
|
804 |
+
if cap_for_img is None and subset.class_tokens is None:
|
805 |
+
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
806 |
+
captions.append("")
|
807 |
+
else:
|
808 |
+
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
809 |
|
810 |
+
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
811 |
|
812 |
+
return img_paths, captions
|
813 |
|
814 |
+
print("prepare images.")
|
|
|
815 |
num_train_images = 0
|
816 |
+
num_reg_images = 0
|
817 |
+
reg_infos: List[ImageInfo] = []
|
818 |
+
for subset in subsets:
|
819 |
+
if subset.num_repeats < 1:
|
820 |
+
print(
|
821 |
+
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
822 |
+
continue
|
823 |
+
|
824 |
+
if subset in self.subsets:
|
825 |
+
print(
|
826 |
+
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
827 |
+
continue
|
828 |
+
|
829 |
+
img_paths, captions = load_dreambooth_dir(subset)
|
830 |
+
if len(img_paths) < 1:
|
831 |
+
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
832 |
+
continue
|
833 |
+
|
834 |
+
if subset.is_reg:
|
835 |
+
num_reg_images += subset.num_repeats * len(img_paths)
|
836 |
+
else:
|
837 |
+
num_train_images += subset.num_repeats * len(img_paths)
|
838 |
|
839 |
for img_path, caption in zip(img_paths, captions):
|
840 |
+
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
841 |
+
if subset.is_reg:
|
842 |
+
reg_infos.append(info)
|
843 |
+
else:
|
844 |
+
self.register_image(info, subset)
|
845 |
|
846 |
+
subset.img_count = len(img_paths)
|
847 |
+
self.subsets.append(subset)
|
848 |
|
849 |
print(f"{num_train_images} train images with repeating.")
|
850 |
self.num_train_images = num_train_images
|
851 |
|
852 |
+
print(f"{num_reg_images} reg images.")
|
853 |
+
if num_train_images < num_reg_images:
|
854 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
|
|
|
|
855 |
|
856 |
+
if num_reg_images == 0:
|
857 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
858 |
+
else:
|
859 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
860 |
+
n = 0
|
861 |
+
first_loop = True
|
862 |
+
while n < num_train_images:
|
863 |
+
for info in reg_infos:
|
864 |
+
if first_loop:
|
865 |
+
self.register_image(info, subset)
|
866 |
+
n += info.num_repeats
|
867 |
+
else:
|
868 |
+
info.num_repeats += 1
|
869 |
+
n += 1
|
870 |
+
if n >= num_train_images:
|
871 |
+
break
|
872 |
+
first_loop = False
|
873 |
|
874 |
+
self.num_reg_images = num_reg_images
|
|
|
|
|
875 |
|
|
|
876 |
|
877 |
+
class FineTuningDataset(BaseDataset):
|
878 |
+
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
879 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
880 |
+
|
881 |
+
self.batch_size = batch_size
|
882 |
+
|
883 |
+
self.num_train_images = 0
|
884 |
+
self.num_reg_images = 0
|
885 |
|
886 |
+
for subset in subsets:
|
887 |
+
if subset.num_repeats < 1:
|
888 |
+
print(
|
889 |
+
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
890 |
+
continue
|
891 |
+
|
892 |
+
if subset in self.subsets:
|
893 |
+
print(
|
894 |
+
f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
895 |
+
continue
|
896 |
+
|
897 |
+
# メタデータを読み込む
|
898 |
+
if os.path.exists(subset.metadata_file):
|
899 |
+
print(f"loading existing metadata: {subset.metadata_file}")
|
900 |
+
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
901 |
+
metadata = json.load(f)
|
902 |
else:
|
903 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
904 |
|
905 |
+
if len(metadata) < 1:
|
906 |
+
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
907 |
+
continue
|
908 |
|
909 |
+
tags_list = []
|
910 |
+
for image_key, img_md in metadata.items():
|
911 |
+
# path情報を作る
|
912 |
+
if os.path.exists(image_key):
|
913 |
+
abs_path = image_key
|
914 |
+
else:
|
915 |
+
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
916 |
+
if os.path.exists(npz_path):
|
917 |
+
abs_path = npz_path
|
918 |
+
else:
|
919 |
+
# わりといい加減だがいい方法が思いつかん
|
920 |
+
abs_path = glob_images(subset.image_dir, image_key)
|
921 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
922 |
+
abs_path = abs_path[0]
|
923 |
|
924 |
+
caption = img_md.get('caption')
|
925 |
+
tags = img_md.get('tags')
|
926 |
+
if caption is None:
|
927 |
+
caption = tags
|
928 |
+
elif tags is not None and len(tags) > 0:
|
929 |
+
caption = caption + ', ' + tags
|
930 |
+
tags_list.append(tags)
|
|
|
|
|
|
|
|
|
|
|
931 |
|
932 |
+
if caption is None:
|
933 |
+
caption = ""
|
|
|
934 |
|
935 |
+
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
936 |
+
image_info.image_size = img_md.get('train_resolution')
|
937 |
+
|
938 |
+
if not subset.color_aug and not subset.random_crop:
|
939 |
+
# if npz exists, use them
|
940 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
941 |
+
|
942 |
+
self.register_image(image_info, subset)
|
|
|
|
|
|
|
|
|
|
|
|
943 |
|
944 |
+
self.num_train_images += len(metadata) * subset.num_repeats
|
945 |
+
|
946 |
+
# TODO do not record tag freq when no tag
|
947 |
+
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
|
948 |
+
subset.img_count = len(metadata)
|
949 |
+
self.subsets.append(subset)
|
950 |
|
951 |
# check existence of all npz files
|
952 |
+
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
953 |
if use_npz_latents:
|
954 |
+
flip_aug_in_subset = False
|
955 |
npz_any = False
|
956 |
npz_all = True
|
957 |
+
|
958 |
for image_info in self.image_data.values():
|
959 |
+
subset = self.image_to_subset[image_info.image_key]
|
960 |
+
|
961 |
has_npz = image_info.latents_npz is not None
|
962 |
npz_any = npz_any or has_npz
|
963 |
|
964 |
+
if subset.flip_aug:
|
965 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
966 |
+
flip_aug_in_subset = True
|
967 |
npz_all = npz_all and has_npz
|
968 |
|
969 |
if npz_any and not npz_all:
|
|
|
975 |
elif not npz_all:
|
976 |
use_npz_latents = False
|
977 |
print(f"some of npz file does not exist. ignore npz files / いくつ���のnpzファイルが見つからないためnpzファイルを無視します")
|
978 |
+
if flip_aug_in_subset:
|
979 |
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
980 |
# else:
|
981 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
|
|
1021 |
for image_info in self.image_data.values():
|
1022 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
1023 |
|
1024 |
+
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
1025 |
base_name = os.path.splitext(image_key)[0]
|
1026 |
npz_file_norm = base_name + '.npz'
|
1027 |
|
|
|
1033 |
return npz_file_norm, npz_file_flip
|
1034 |
|
1035 |
# image_key is relative path
|
1036 |
+
npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
|
1037 |
+
npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
|
1038 |
|
1039 |
if not os.path.exists(npz_file_norm):
|
1040 |
npz_file_norm = None
|
|
|
1045 |
return npz_file_norm, npz_file_flip
|
1046 |
|
1047 |
|
1048 |
+
# behave as Dataset mock
|
1049 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1050 |
+
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
1051 |
+
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
1052 |
+
|
1053 |
+
super().__init__(datasets)
|
1054 |
+
|
1055 |
+
self.image_data = {}
|
1056 |
+
self.num_train_images = 0
|
1057 |
+
self.num_reg_images = 0
|
1058 |
+
|
1059 |
+
# simply concat together
|
1060 |
+
# TODO: handling image_data key duplication among dataset
|
1061 |
+
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
1062 |
+
for dataset in datasets:
|
1063 |
+
self.image_data.update(dataset.image_data)
|
1064 |
+
self.num_train_images += dataset.num_train_images
|
1065 |
+
self.num_reg_images += dataset.num_reg_images
|
1066 |
+
|
1067 |
+
def add_replacement(self, str_from, str_to):
|
1068 |
+
for dataset in self.datasets:
|
1069 |
+
dataset.add_replacement(str_from, str_to)
|
1070 |
+
|
1071 |
+
# def make_buckets(self):
|
1072 |
+
# for dataset in self.datasets:
|
1073 |
+
# dataset.make_buckets()
|
1074 |
+
|
1075 |
+
def cache_latents(self, vae):
|
1076 |
+
for i, dataset in enumerate(self.datasets):
|
1077 |
+
print(f"[Dataset {i}]")
|
1078 |
+
dataset.cache_latents(vae)
|
1079 |
+
|
1080 |
+
def is_latent_cacheable(self) -> bool:
|
1081 |
+
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
1082 |
+
|
1083 |
+
def set_current_epoch(self, epoch):
|
1084 |
+
for dataset in self.datasets:
|
1085 |
+
dataset.set_current_epoch(epoch)
|
1086 |
+
|
1087 |
+
def disable_token_padding(self):
|
1088 |
+
for dataset in self.datasets:
|
1089 |
+
dataset.disable_token_padding()
|
1090 |
+
|
1091 |
+
|
1092 |
def debug_dataset(train_dataset, show_input_ids=False):
|
1093 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
1094 |
print("Escape for exit. / Escキーで中断、終了します")
|
1095 |
|
1096 |
train_dataset.set_current_epoch(1)
|
1097 |
k = 0
|
1098 |
+
indices = list(range(len(train_dataset)))
|
1099 |
+
random.shuffle(indices)
|
1100 |
+
for i, idx in enumerate(indices):
|
1101 |
+
example = train_dataset[idx]
|
1102 |
if example['latents'] is not None:
|
1103 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
1104 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
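The DatasetGroup added in this hunk lets the training scripts treat several DreamBooth and fine-tuning datasets as a single Dataset. A hedged usage sketch; the datasets and vae are assumed to be built elsewhere, for example as in the previous snippet.

from library.train_util import DatasetGroup

def group_and_cache(datasets, vae, cache_latents):
    group = DatasetGroup(datasets)
    if cache_latents:
        # subsets with color_aug or random_crop cannot use cached latents
        assert group.is_latent_cacheable(), "color_aug / random_crop prevent latent caching"
        group.cache_latents(vae)
    group.set_current_epoch(1)  # also reshuffles each dataset's buckets
    return group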
|
|
|
1503 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1504 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1505 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1506 |
+
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
1507 |
+
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
1508 |
+
|
1509 |
+
|
1510 |
+
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
1511 |
+
parser.add_argument("--optimizer_type", type=str, default="",
|
1512 |
+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
|
1513 |
+
|
1514 |
+
# backward compatibility
|
1515 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1516 |
+
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1517 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1518 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1519 |
+
|
1520 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1521 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
1522 |
+
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
|
1523 |
+
|
1524 |
+
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
|
1525 |
+
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
1526 |
+
|
1527 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1528 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
|
1529 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1530 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1531 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
1532 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
1533 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
1534 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
1535 |
|
1536 |
|
1537 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
|
|
1555 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1556 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1557 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
|
|
|
|
|
|
|
|
1558 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1559 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1560 |
parser.add_argument("--xformers", action="store_true",
|
|
|
1562 |
parser.add_argument("--vae", type=str, default=None,
|
1563 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1564 |
|
|
|
1565 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1566 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1567 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
|
|
1582 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1583 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1584 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
|
|
|
|
|
|
|
|
1585 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1586 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1587 |
parser.add_argument("--lowram", action="store_true",
|
1588 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1589 |
|
1590 |
+
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
1591 |
+
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
1592 |
+
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
1593 |
+
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
1594 |
+
parser.add_argument("--sample_prompts", type=str, default=None,
|
1595 |
+
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
1596 |
+
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
1597 |
+
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
1598 |
+
'dpmsolver++', 'dpmsingle',
|
1599 |
+
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
1600 |
+
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
1601 |
+
|
1602 |
if support_dreambooth:
|
1603 |
# DreamBooth training
|
1604 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
|
|
1620 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1621 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1622 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1623 |
+
parser.add_argument("--keep_tokens", type=int, default=0,
|
1624 |
+
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
|
1625 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1626 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1627 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
|
|
1646 |
if support_caption_dropout:
|
1647 |
# Textual Inversion does not support caption dropout
|
1648 |
# prefix these options with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults to None to match the other options
|
1649 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
|
1650 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1651 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
|
1652 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1653 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
|
1654 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1655 |
|
1656 |
if support_dreambooth:
|
|
|
1675 |
# region utils
|
1676 |
|
1677 |
|
1678 |
+
def get_optimizer(args, trainable_params):
|
1679 |
+
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
1680 |
+
|
1681 |
+
optimizer_type = args.optimizer_type
|
1682 |
+
if args.use_8bit_adam:
|
1683 |
+
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
|
1684 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
|
1685 |
+
optimizer_type = "AdamW8bit"
|
1686 |
+
|
1687 |
+
elif args.use_lion_optimizer:
|
1688 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
|
1689 |
+
optimizer_type = "Lion"
|
1690 |
+
|
1691 |
+
if optimizer_type is None or optimizer_type == "":
|
1692 |
+
optimizer_type = "AdamW"
|
1693 |
+
optimizer_type = optimizer_type.lower()
|
1694 |
+
|
1695 |
+
# parse the extra arguments: only bool, float and tuple values are supported
|
1696 |
+
optimizer_kwargs = {}
|
1697 |
+
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
1698 |
+
for arg in args.optimizer_args:
|
1699 |
+
key, value = arg.split('=')
|
1700 |
+
|
1701 |
+
value = value.split(",")
|
1702 |
+
for i in range(len(value)):
|
1703 |
+
if value[i].lower() == "true" or value[i].lower() == "false":
|
1704 |
+
value[i] = (value[i].lower() == "true")
|
1705 |
+
else:
|
1706 |
+
value[i] = float(value[i])
|
1707 |
+
if len(value) == 1:
|
1708 |
+
value = value[0]
|
1709 |
+
else:
|
1710 |
+
value = tuple(value)
|
1711 |
+
|
1712 |
+
optimizer_kwargs[key] = value
|
1713 |
+
# print("optkwargs:", optimizer_kwargs)
|
1714 |
+
|
1715 |
+
lr = args.learning_rate
|
1716 |
+
|
1717 |
+
if optimizer_type == "AdamW8bit".lower():
|
1718 |
+
try:
|
1719 |
+
import bitsandbytes as bnb
|
1720 |
+
except ImportError:
|
1721 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1722 |
+
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
1723 |
+
optimizer_class = bnb.optim.AdamW8bit
|
1724 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1725 |
+
|
1726 |
+
elif optimizer_type == "SGDNesterov8bit".lower():
|
1727 |
+
try:
|
1728 |
+
import bitsandbytes as bnb
|
1729 |
+
except ImportError:
|
1730 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1731 |
+
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1732 |
+
if "momentum" not in optimizer_kwargs:
|
1733 |
+
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1734 |
+
optimizer_kwargs["momentum"] = 0.9
|
1735 |
+
|
1736 |
+
optimizer_class = bnb.optim.SGD8bit
|
1737 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1738 |
+
|
1739 |
+
elif optimizer_type == "Lion".lower():
|
1740 |
+
try:
|
1741 |
+
import lion_pytorch
|
1742 |
+
except ImportError:
|
1743 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
1744 |
+
print(f"use Lion optimizer | {optimizer_kwargs}")
|
1745 |
+
optimizer_class = lion_pytorch.Lion
|
1746 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1747 |
+
|
1748 |
+
elif optimizer_type == "SGDNesterov".lower():
|
1749 |
+
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1750 |
+
if "momentum" not in optimizer_kwargs:
|
1751 |
+
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1752 |
+
optimizer_kwargs["momentum"] = 0.9
|
1753 |
+
|
1754 |
+
optimizer_class = torch.optim.SGD
|
1755 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1756 |
+
|
1757 |
+
elif optimizer_type == "DAdaptation".lower():
|
1758 |
+
try:
|
1759 |
+
import dadaptation
|
1760 |
+
except ImportError:
|
1761 |
+
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
1762 |
+
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
1763 |
+
|
1764 |
+
actual_lr = lr
|
1765 |
+
lr_count = 1
|
1766 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1767 |
+
lrs = set()
|
1768 |
+
actual_lr = trainable_params[0].get("lr", actual_lr)
|
1769 |
+
for group in trainable_params:
|
1770 |
+
lrs.add(group.get("lr", actual_lr))
|
1771 |
+
lr_count = len(lrs)
|
1772 |
+
|
1773 |
+
if actual_lr <= 0.1:
|
1774 |
+
print(
|
1775 |
+
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
|
1776 |
+
print('recommend option: lr=1.0 / 推奨は1.0です')
|
1777 |
+
if lr_count > 1:
|
1778 |
+
print(
|
1779 |
+
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
|
1780 |
+
|
1781 |
+
optimizer_class = dadaptation.DAdaptAdam
|
1782 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1783 |
+
|
1784 |
+
elif optimizer_type == "Adafactor".lower():
|
1785 |
+
# check the arguments and correct them as needed
|
1786 |
+
if "relative_step" not in optimizer_kwargs:
|
1787 |
+
optimizer_kwargs["relative_step"] = True # default
|
1788 |
+
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
1789 |
+
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
|
1790 |
+
optimizer_kwargs["relative_step"] = True
|
1791 |
+
print(f"use Adafactor optimizer | {optimizer_kwargs}")
|
1792 |
+
|
1793 |
+
if optimizer_kwargs["relative_step"]:
|
1794 |
+
print(f"relative_step is true / relative_stepがtrueです")
|
1795 |
+
if lr != 0.0:
|
1796 |
+
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
1797 |
+
args.learning_rate = None
|
1798 |
+
|
1799 |
+
# when trainable_params is a list of param groups, remove the per-group lr
|
1800 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1801 |
+
has_group_lr = False
|
1802 |
+
for group in trainable_params:
|
1803 |
+
p = group.pop("lr", None)
|
1804 |
+
has_group_lr = has_group_lr or (p is not None)
|
1805 |
+
|
1806 |
+
if has_group_lr:
|
1807 |
+
# invalidate the args just in case  TODO: this inverts the dependency direction, which is not ideal
|
1808 |
+
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
|
1809 |
+
args.unet_lr = None
|
1810 |
+
args.text_encoder_lr = None
|
1811 |
+
|
1812 |
+
if args.lr_scheduler != "adafactor":
|
1813 |
+
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
1814 |
+
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
1815 |
+
|
1816 |
+
lr = None
|
1817 |
+
else:
|
1818 |
+
if args.max_grad_norm != 0.0:
|
1819 |
+
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
|
1820 |
+
if args.lr_scheduler != "constant_with_warmup":
|
1821 |
+
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
1822 |
+
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
1823 |
+
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
1824 |
+
|
1825 |
+
optimizer_class = transformers.optimization.Adafactor
|
1826 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1827 |
+
|
1828 |
+
elif optimizer_type == "AdamW".lower():
|
1829 |
+
print(f"use AdamW optimizer | {optimizer_kwargs}")
|
1830 |
+
optimizer_class = torch.optim.AdamW
|
1831 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1832 |
+
|
1833 |
+
else:
|
1834 |
+
# use an arbitrary optimizer specified by module path
|
1835 |
+
optimizer_type = args.optimizer_type  # use the non-lowercased name (a bit awkward)
|
1836 |
+
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
1837 |
+
if "." not in optimizer_type:
|
1838 |
+
optimizer_module = torch.optim
|
1839 |
+
else:
|
1840 |
+
values = optimizer_type.split(".")
|
1841 |
+
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
1842 |
+
optimizer_type = values[-1]
|
1843 |
+
|
1844 |
+
optimizer_class = getattr(optimizer_module, optimizer_type)
|
1845 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1846 |
+
|
1847 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
1848 |
+
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
1849 |
+
|
1850 |
+
return optimizer_name, optimizer_args, optimizer
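A minimal sketch of how --optimizer_args entries become keyword arguments in the parsing loop above (only bool, float and comma-separated tuple values are handled); the example values are illustrative only:

optimizer_args = ["weight_decay=0.01", "betas=0.9,0.999", "amsgrad=true"]

optimizer_kwargs = {}
for arg in optimizer_args:
    key, value = arg.split('=')
    value = value.split(',')
    for i in range(len(value)):
        if value[i].lower() in ("true", "false"):
            value[i] = (value[i].lower() == "true")   # booleans
        else:
            value[i] = float(value[i])                # everything else is treated as a float
    optimizer_kwargs[key] = value[0] if len(value) == 1 else tuple(value)

print(optimizer_kwargs)  # {'weight_decay': 0.01, 'betas': (0.9, 0.999), 'amsgrad': True}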
|
1851 |
+
|
1852 |
+
|
1853 |
+
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
|
1854 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
1855 |
+
# which is a newer release of diffusers than the one currently packaged with sd-scripts
|
1856 |
+
# This code can be removed once a newer diffusers version (v0.12.1 or greater) is tested and adopted in sd-scripts
|
1857 |
+
|
1858 |
+
|
1859 |
+
def get_scheduler_fix(
|
1860 |
+
name: Union[str, SchedulerType],
|
1861 |
+
optimizer: Optimizer,
|
1862 |
+
num_warmup_steps: Optional[int] = None,
|
1863 |
+
num_training_steps: Optional[int] = None,
|
1864 |
+
num_cycles: int = 1,
|
1865 |
+
power: float = 1.0,
|
1866 |
+
):
|
1867 |
+
"""
|
1868 |
+
Unified API to get any scheduler from its name.
|
1869 |
+
Args:
|
1870 |
+
name (`str` or `SchedulerType`):
|
1871 |
+
The name of the scheduler to use.
|
1872 |
+
optimizer (`torch.optim.Optimizer`):
|
1873 |
+
The optimizer that will be used during training.
|
1874 |
+
num_warmup_steps (`int`, *optional*):
|
1875 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
1876 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1877 |
+
num_training_steps (`int`, *optional*):
|
1878 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
1879 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1880 |
+
num_cycles (`int`, *optional*):
|
1881 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
1882 |
+
power (`float`, *optional*, defaults to 1.0):
|
1883 |
+
Power factor. See `POLYNOMIAL` scheduler
|
1884 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
1885 |
+
The index of the last epoch when resuming training.
|
1886 |
+
"""
|
1887 |
+
if name.startswith("adafactor"):
|
1888 |
+
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
1889 |
+
initial_lr = float(name.split(':')[1])
|
1890 |
+
# print("adafactor scheduler init lr", initial_lr)
|
1891 |
+
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
1892 |
+
|
1893 |
+
name = SchedulerType(name)
|
1894 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
1895 |
+
if name == SchedulerType.CONSTANT:
|
1896 |
+
return schedule_func(optimizer)
|
1897 |
+
|
1898 |
+
# All other schedulers require `num_warmup_steps`
|
1899 |
+
if num_warmup_steps is None:
|
1900 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
1901 |
+
|
1902 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
1903 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
1904 |
+
|
1905 |
+
# All other schedulers require `num_training_steps`
|
1906 |
+
if num_training_steps is None:
|
1907 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
1908 |
+
|
1909 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
1910 |
+
return schedule_func(
|
1911 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
1912 |
+
)
|
1913 |
+
|
1914 |
+
if name == SchedulerType.POLYNOMIAL:
|
1915 |
+
return schedule_func(
|
1916 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
1917 |
+
)
|
1918 |
+
|
1919 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
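A minimal usage sketch of the wrapper above, assuming it is imported from library.train_util; a throwaway parameter and AdamW stand in for the real training setup:

import torch
from library.train_util import get_scheduler_fix

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_scheduler = get_scheduler_fix(
    "cosine_with_restarts", optimizer,
    num_warmup_steps=100, num_training_steps=1600, num_cycles=2)

for _ in range(10):
    optimizer.step()     # step the optimizer first
    lr_scheduler.step()  # then advance the schedule
print(lr_scheduler.get_last_lr())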
|
1920 |
+
|
1921 |
+
|
1922 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1923 |
# backward compatibility
|
1924 |
if args.caption_extention is not None:
|
1925 |
args.caption_extension = args.caption_extention
|
1926 |
args.caption_extention = None
|
1927 |
|
|
|
|
|
|
|
|
|
1928 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1929 |
if args.resolution is not None:
|
1930 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
|
|
1947 |
|
1948 |
def load_tokenizer(args: argparse.Namespace):
|
1949 |
print("prepare tokenizer")
|
1950 |
+
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
1951 |
+
|
1952 |
+
tokenizer: CLIPTokenizer = None
|
1953 |
+
if args.tokenizer_cache_dir:
|
1954 |
+
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
|
1955 |
+
if os.path.exists(local_tokenizer_path):
|
1956 |
+
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
1957 |
+
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
1958 |
+
|
1959 |
+
if tokenizer is None:
|
1960 |
+
if args.v2:
|
1961 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
1962 |
+
else:
|
1963 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
1964 |
+
|
1965 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
1966 |
print(f"update token length: {args.max_token_length}")
|
1967 |
+
|
1968 |
+
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
1969 |
+
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
1970 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
1971 |
+
|
1972 |
return tokenizer
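A minimal sketch of the cache layout used above: the Hugging Face repo id is flattened into a directory name so later runs can load the tokenizer without network access. The repo id is the usual SD v1 tokenizer and the cache directory is a hypothetical --tokenizer_cache_dir value:

import os
from transformers import CLIPTokenizer

tokenizer_cache_dir = "./tokenizer_cache"        # hypothetical cache location
original_path = "openai/clip-vit-large-patch14"  # SD v1 tokenizer repo id

local_path = os.path.join(tokenizer_cache_dir, original_path.replace('/', '_'))
if os.path.exists(local_path):
    tokenizer = CLIPTokenizer.from_pretrained(local_path)    # offline load from cache
else:
    tokenizer = CLIPTokenizer.from_pretrained(original_path)
    tokenizer.save_pretrained(local_path)                    # fill the cache for next time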
|
1973 |
|
1974 |
|
|
|
2019 |
|
2020 |
|
2021 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
2022 |
+
name_or_path = args.pretrained_model_name_or_path
|
2023 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
2024 |
+
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
2025 |
if load_stable_diffusion_format:
|
2026 |
print("load StableDiffusion checkpoint")
|
2027 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
|
2028 |
else:
|
2029 |
print("load Diffusers pretrained models")
|
2030 |
+
try:
|
2031 |
+
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
2032 |
+
except EnvironmentError as ex:
|
2033 |
+
print(
|
2034 |
+
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
2035 |
text_encoder = pipe.text_encoder
|
2036 |
vae = pipe.vae
|
2037 |
unet = pipe.unet
|
|
|
2200 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
2201 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
2202 |
|
2203 |
+
|
2204 |
+
# scheduler:
|
2205 |
+
SCHEDULER_LINEAR_START = 0.00085
|
2206 |
+
SCHEDULER_LINEAR_END = 0.0120
|
2207 |
+
SCHEDULER_TIMESTEPS = 1000
|
2208 |
+
SCHEDLER_SCHEDULE = 'scaled_linear'
|
2209 |
+
|
2210 |
+
|
2211 |
+
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
2212 |
+
"""
|
2213 |
+
The default Diffusers pipeline is used for generation, so prompt weighting is not supported.
|
2214 |
+
clip skip is supported.
|
2215 |
+
"""
|
2216 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
2217 |
+
return
|
2218 |
+
if args.sample_every_n_epochs is not None:
|
2219 |
+
# sample_every_n_steps is ignored
|
2220 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
2221 |
+
return
|
2222 |
+
else:
|
2223 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
2224 |
+
return
|
2225 |
+
|
2226 |
+
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
2227 |
+
if not os.path.isfile(args.sample_prompts):
|
2228 |
+
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
2229 |
+
return
|
2230 |
+
|
2231 |
+
org_vae_device = vae.device  # should be on the CPU
|
2232 |
+
vae.to(device)
|
2233 |
+
|
2234 |
+
# create a wrapper to support clip skip
|
2235 |
+
if args.clip_skip is None:
|
2236 |
+
text_encoder_or_wrapper = text_encoder
|
2237 |
+
else:
|
2238 |
+
class Wrapper():
|
2239 |
+
def __init__(self, tenc) -> None:
|
2240 |
+
self.tenc = tenc
|
2241 |
+
self.config = {}
|
2242 |
+
super().__init__()
|
2243 |
+
|
2244 |
+
def __call__(self, input_ids, attention_mask):
|
2245 |
+
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
2246 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
2247 |
+
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
2248 |
+
pooled_output = enc_out['pooler_output']
|
2249 |
+
return encoder_hidden_states, pooled_output  # only the 1st output is used
|
2250 |
+
|
2251 |
+
text_encoder_or_wrapper = Wrapper(text_encoder)
|
2252 |
+
|
2253 |
+
# read prompts
|
2254 |
+
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
|
2255 |
+
prompts = f.readlines()
|
2256 |
+
|
2257 |
+
# prepare the scheduler
|
2258 |
+
sched_init_args = {}
|
2259 |
+
if args.sample_sampler == "ddim":
|
2260 |
+
scheduler_cls = DDIMScheduler
|
2261 |
+
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
2262 |
+
scheduler_cls = DDPMScheduler
|
2263 |
+
elif args.sample_sampler == "pndm":
|
2264 |
+
scheduler_cls = PNDMScheduler
|
2265 |
+
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
|
2266 |
+
scheduler_cls = LMSDiscreteScheduler
|
2267 |
+
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
|
2268 |
+
scheduler_cls = EulerDiscreteScheduler
|
2269 |
+
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
|
2270 |
+
scheduler_cls = EulerAncestralDiscreteScheduler
|
2271 |
+
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
2272 |
+
scheduler_cls = DPMSolverMultistepScheduler
|
2273 |
+
sched_init_args['algorithm_type'] = args.sample_sampler
|
2274 |
+
elif args.sample_sampler == "dpmsingle":
|
2275 |
+
scheduler_cls = DPMSolverSinglestepScheduler
|
2276 |
+
elif args.sample_sampler == "heun":
|
2277 |
+
scheduler_cls = HeunDiscreteScheduler
|
2278 |
+
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
|
2279 |
+
scheduler_cls = KDPM2DiscreteScheduler
|
2280 |
+
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
|
2281 |
+
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
2282 |
+
else:
|
2283 |
+
scheduler_cls = DDIMScheduler
|
2284 |
+
|
2285 |
+
if args.v_parameterization:
|
2286 |
+
sched_init_args['prediction_type'] = 'v_prediction'
|
2287 |
+
|
2288 |
+
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
|
2289 |
+
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
|
2290 |
+
beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
|
2291 |
+
|
2292 |
+
# force clip_sample=True
|
2293 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
2294 |
+
# print("set clip_sample to True")
|
2295 |
+
scheduler.config.clip_sample = True
|
2296 |
+
|
2297 |
+
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
2298 |
+
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
2299 |
+
pipeline.to(device)
|
2300 |
+
|
2301 |
+
save_dir = args.output_dir + "/sample"
|
2302 |
+
os.makedirs(save_dir, exist_ok=True)
|
2303 |
+
|
2304 |
+
rng_state = torch.get_rng_state()
|
2305 |
+
cuda_rng_state = torch.cuda.get_rng_state()
|
2306 |
+
|
2307 |
+
with torch.no_grad():
|
2308 |
+
with accelerator.autocast():
|
2309 |
+
for i, prompt in enumerate(prompts):
|
2310 |
+
if not accelerator.is_main_process:
|
2311 |
+
continue
|
2312 |
+
prompt = prompt.strip()
|
2313 |
+
if len(prompt) == 0 or prompt[0] == '#':
|
2314 |
+
continue
|
2315 |
+
|
2316 |
+
# subset of gen_img_diffusers
|
2317 |
+
prompt_args = prompt.split(' --')
|
2318 |
+
prompt = prompt_args[0]
|
2319 |
+
negative_prompt = None
|
2320 |
+
sample_steps = 30
|
2321 |
+
width = height = 512
|
2322 |
+
scale = 7.5
|
2323 |
+
seed = None
|
2324 |
+
for parg in prompt_args:
|
2325 |
+
try:
|
2326 |
+
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
|
2327 |
+
if m:
|
2328 |
+
width = int(m.group(1))
|
2329 |
+
continue
|
2330 |
+
|
2331 |
+
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
|
2332 |
+
if m:
|
2333 |
+
height = int(m.group(1))
|
2334 |
+
continue
|
2335 |
+
|
2336 |
+
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
|
2337 |
+
if m:
|
2338 |
+
seed = int(m.group(1))
|
2339 |
+
continue
|
2340 |
+
|
2341 |
+
m = re.match(r's (\d+)', parg, re.IGNORECASE)
|
2342 |
+
if m: # steps
|
2343 |
+
sample_steps = max(1, min(1000, int(m.group(1))))
|
2344 |
+
continue
|
2345 |
+
|
2346 |
+
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
|
2347 |
+
if m: # scale
|
2348 |
+
scale = float(m.group(1))
|
2349 |
+
continue
|
2350 |
+
|
2351 |
+
m = re.match(r'n (.+)', parg, re.IGNORECASE)
|
2352 |
+
if m: # negative prompt
|
2353 |
+
negative_prompt = m.group(1)
|
2354 |
+
continue
|
2355 |
+
|
2356 |
+
except ValueError as ex:
|
2357 |
+
print(f"Exception in parsing / 解析エラー: {parg}")
|
2358 |
+
print(ex)
|
2359 |
+
|
2360 |
+
if seed is not None:
|
2361 |
+
torch.manual_seed(seed)
|
2362 |
+
torch.cuda.manual_seed(seed)
|
2363 |
+
|
2364 |
+
if prompt_replacement is not None:
|
2365 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2366 |
+
if negative_prompt is not None:
|
2367 |
+
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2368 |
+
|
2369 |
+
height = max(64, height - height % 8) # round to divisible by 8
|
2370 |
+
width = max(64, width - width % 8) # round to divisible by 8
|
2371 |
+
print(f"prompt: {prompt}")
|
2372 |
+
print(f"negative_prompt: {negative_prompt}")
|
2373 |
+
print(f"height: {height}")
|
2374 |
+
print(f"width: {width}")
|
2375 |
+
print(f"sample_steps: {sample_steps}")
|
2376 |
+
print(f"scale: {scale}")
|
2377 |
+
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
2378 |
+
|
2379 |
+
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
2380 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
2381 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
2382 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
2383 |
+
|
2384 |
+
image.save(os.path.join(save_dir, img_filename))
|
2385 |
+
|
2386 |
+
# clear pipeline and cache to reduce vram usage
|
2387 |
+
del pipeline
|
2388 |
+
torch.cuda.empty_cache()
|
2389 |
+
|
2390 |
+
torch.set_rng_state(rng_state)
|
2391 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
2392 |
+
vae.to(org_vae_device)
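A sketch of the prompt file format consumed above: one prompt per line, options appended after ' --' separators (n = negative prompt, w/h = width/height, d = seed, s = steps, l = CFG scale). The line below is illustrative only:

import re

line = "1girl, masterpiece --n lowres, bad anatomy --w 768 --h 512 --d 42 --s 28 --l 7.0"

parts = line.split(' --')
prompt = parts[0]
opts = {}
for parg in parts[1:]:
    m = re.match(r'(\w) (.+)', parg)
    if m:
        opts[m.group(1)] = m.group(2)

print(prompt)  # 1girl, masterpiece
print(opts)    # {'n': 'lowres, bad anatomy', 'w': '768', 'h': '512', 'd': '42', 's': '28', 'l': '7.0'}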
|
2393 |
+
|
2394 |
# endregion
|
2395 |
|
2396 |
# region preprocessing
|
networks/check_lora_weights.py
CHANGED
@@ -21,7 +21,7 @@ def main(file):
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
-
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
21 |
|
22 |
for key, value in values:
|
23 |
value = value.to(torch.float32)
|
24 |
+
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
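The added column in the print above encodes the weight shape with the comma replaced, so the CSV-like output keeps one field per value; a minimal sketch of that formatting:

import torch

value = torch.randn(320, 4)
shape_field = str(tuple(value.size())).replace(', ', '-')
print(shape_field)  # (320-4)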
|
networks/extract_lora_from_models.py
CHANGED
@@ -45,8 +45,13 @@ def svd(args):
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
assert len(lora_network_o.text_encoder_loras) == len(
|
51 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
52 |
|
@@ -85,13 +90,28 @@ def svd(args):
|
|
85 |
|
86 |
# make LoRA with svd
|
87 |
print("calculating by svd")
|
88 |
-
rank = args.dim
|
89 |
lora_weights = {}
|
90 |
with torch.no_grad():
|
91 |
for lora_name, mat in tqdm(list(diffs.items())):
|
|
|
92 |
conv2d = (len(mat.size()) == 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
if conv2d:
|
94 |
-
|
|
|
|
|
|
|
95 |
|
96 |
U, S, Vh = torch.linalg.svd(mat)
|
97 |
|
@@ -108,30 +128,27 @@ def svd(args):
|
|
108 |
U = U.clamp(low_val, hi_val)
|
109 |
Vh = Vh.clamp(low_val, hi_val)
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
115 |
-
lora_sd = lora_network_o.state_dict()
|
116 |
-
print(f"LoRA has {len(lora_sd)} weights.")
|
117 |
-
|
118 |
-
for key in list(lora_sd.keys()):
|
119 |
-
if "alpha" in key:
|
120 |
-
continue
|
121 |
|
122 |
-
|
123 |
-
|
124 |
|
125 |
-
|
126 |
-
# print(key, i, weights.size(), lora_sd[key].size())
|
127 |
-
if len(lora_sd[key].size()) == 4:
|
128 |
-
weights = weights.unsqueeze(2).unsqueeze(3)
|
129 |
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# load state dict to LoRA and save it
|
134 |
-
|
|
|
|
|
|
|
135 |
print(f"Loading extracted LoRA weights: {info}")
|
136 |
|
137 |
dir_name = os.path.dirname(args.save_to)
|
@@ -139,9 +156,9 @@ def svd(args):
|
|
139 |
os.makedirs(dir_name, exist_ok=True)
|
140 |
|
141 |
# minimum metadata
|
142 |
-
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
143 |
|
144 |
-
|
145 |
print(f"LoRA weights are saved to: {args.save_to}")
|
146 |
|
147 |
|
@@ -158,6 +175,8 @@ if __name__ == '__main__':
|
|
158 |
parser.add_argument("--save_to", type=str, default=None,
|
159 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
160 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
|
|
|
|
161 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
|
162 |
|
163 |
args = parser.parse_args()
|
|
|
45 |
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
|
47 |
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
+
if args.conv_dim is None:
|
49 |
+
kwargs = {}
|
50 |
+
else:
|
51 |
+
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
52 |
+
|
53 |
+
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
54 |
+
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
55 |
assert len(lora_network_o.text_encoder_loras) == len(
|
56 |
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
57 |
|
|
|
90 |
|
91 |
# make LoRA with svd
|
92 |
print("calculating by svd")
|
|
|
93 |
lora_weights = {}
|
94 |
with torch.no_grad():
|
95 |
for lora_name, mat in tqdm(list(diffs.items())):
|
96 |
+
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
97 |
conv2d = (len(mat.size()) == 4)
|
98 |
+
kernel_size = None if not conv2d else mat.size()[2:4]
|
99 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
100 |
+
|
101 |
+
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
102 |
+
out_dim, in_dim = mat.size()[0:2]
|
103 |
+
|
104 |
+
if args.device:
|
105 |
+
mat = mat.to(args.device)
|
106 |
+
|
107 |
+
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
108 |
+
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
109 |
+
|
110 |
if conv2d:
|
111 |
+
if conv2d_3x3:
|
112 |
+
mat = mat.flatten(start_dim=1)
|
113 |
+
else:
|
114 |
+
mat = mat.squeeze()
|
115 |
|
116 |
U, S, Vh = torch.linalg.svd(mat)
|
117 |
|
|
|
128 |
U = U.clamp(low_val, hi_val)
|
129 |
Vh = Vh.clamp(low_val, hi_val)
|
130 |
|
131 |
+
if conv2d:
|
132 |
+
U = U.reshape(out_dim, rank, 1, 1)
|
133 |
+
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
U = U.to("cpu").contiguous()
|
136 |
+
Vh = Vh.to("cpu").contiguous()
|
137 |
|
138 |
+
lora_weights[lora_name] = (U, Vh)
|
|
|
|
|
|
|
139 |
|
140 |
+
# make state dict for LoRA
|
141 |
+
lora_sd = {}
|
142 |
+
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
143 |
+
lora_sd[lora_name + '.lora_up.weight'] = up_weight
|
144 |
+
lora_sd[lora_name + '.lora_down.weight'] = down_weight
|
145 |
+
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
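A minimal sketch of the low-rank factorization performed above for a 3x3 conv weight difference: flatten to 2D, run SVD, keep the top rank components, then reshape into a 3x3 down conv and a 1x1 up conv (the quantile clamping step is omitted):

import torch

out_dim, in_dim, rank = 8, 4, 2
delta = torch.randn(out_dim, in_dim, 3, 3)   # weight difference of one conv layer

mat = delta.flatten(start_dim=1)             # (out_dim, in_dim*3*3)
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :rank] @ torch.diag(S[:rank])       # absorb singular values into U
Vh = Vh[:rank, :]

lora_up = U.reshape(out_dim, rank, 1, 1)     # 1x1 up conv
lora_down = Vh.reshape(rank, in_dim, 3, 3)   # 3x3 down conv

approx = (U @ Vh).reshape(out_dim, in_dim, 3, 3)
print(torch.dist(delta, approx))             # rank-2 approximation error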
|
146 |
|
147 |
# load state dict to LoRA and save it
|
148 |
+
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
149 |
+
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
150 |
+
|
151 |
+
info = lora_network_save.load_state_dict(lora_sd)
|
152 |
print(f"Loading extracted LoRA weights: {info}")
|
153 |
|
154 |
dir_name = os.path.dirname(args.save_to)
|
|
|
156 |
os.makedirs(dir_name, exist_ok=True)
|
157 |
|
158 |
# minimum metadata
|
159 |
+
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
160 |
|
161 |
+
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
162 |
print(f"LoRA weights are saved to: {args.save_to}")
|
163 |
|
164 |
|
|
|
175 |
parser.add_argument("--save_to", type=str, default=None,
|
176 |
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
177 |
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
178 |
+
parser.add_argument("--conv_dim", type=int, default=None,
|
179 |
+
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
180 |
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイ��、cuda でGPUを使う")
|
181 |
|
182 |
args = parser.parse_args()
|
networks/lora.py
CHANGED
@@ -6,6 +6,7 @@
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
|
|
9 |
import torch
|
10 |
|
11 |
from library import train_util
|
@@ -20,22 +21,34 @@ class LoRAModule(torch.nn.Module):
|
|
20 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
21 |
super().__init__()
|
22 |
self.lora_name = lora_name
|
23 |
-
self.lora_dim = lora_dim
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
28 |
-
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
29 |
-
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
30 |
else:
|
31 |
in_dim = org_module.in_features
|
32 |
out_dim = org_module.out_features
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
if type(alpha) == torch.Tensor:
|
37 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
38 |
-
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
39 |
self.scale = alpha / self.lora_dim
|
40 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
41 |
|
@@ -45,69 +58,192 @@ class LoRAModule(torch.nn.Module):
|
|
45 |
|
46 |
self.multiplier = multiplier
|
47 |
self.org_module = org_module # remove in applying
|
|
|
|
|
48 |
|
49 |
def apply_to(self):
|
50 |
self.org_forward = self.org_module.forward
|
51 |
self.org_module.forward = self.forward
|
52 |
del self.org_module
|
53 |
|
|
|
|
|
|
|
|
|
54 |
def forward(self, x):
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
59 |
if network_dim is None:
|
60 |
network_dim = 4 # default
|
61 |
-
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
62 |
-
return network
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
for key, value in weights_sd.items():
|
76 |
-
if network_alpha is None and 'alpha' in key:
|
77 |
-
network_alpha = value
|
78 |
-
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
79 |
-
network_dim = value.size()[0]
|
80 |
|
81 |
-
if network_alpha is None:
|
82 |
-
network_alpha = network_dim
|
83 |
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
network.weights_sd = weights_sd
|
86 |
return network
|
87 |
|
88 |
|
89 |
class LoRANetwork(torch.nn.Module):
|
|
|
90 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
|
|
91 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
|
95 |
-
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
96 |
super().__init__()
|
97 |
self.multiplier = multiplier
|
|
|
98 |
self.lora_dim = lora_dim
|
99 |
self.alpha = alpha
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# create module instances
|
102 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
103 |
loras = []
|
104 |
for name, module in root_module.named_modules():
|
105 |
if module.__class__.__name__ in target_replace_modules:
|
|
|
106 |
for child_name, child_module in module.named_modules():
|
107 |
-
|
|
|
|
|
|
|
108 |
lora_name = prefix + '.' + name + '.' + child_name
|
109 |
lora_name = lora_name.replace('.', '_')
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
loras.append(lora)
|
112 |
return loras
|
113 |
|
@@ -115,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
|
|
115 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
116 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
119 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
120 |
|
121 |
self.weights_sd = None
|
@@ -126,6 +267,11 @@ class LoRANetwork(torch.nn.Module):
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
129 |
def load_weights(self, file):
|
130 |
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
from safetensors.torch import load_file, safe_open
|
@@ -235,3 +381,18 @@ class LoRANetwork(torch.nn.Module):
|
|
235 |
save_file(state_dict, file, metadata)
|
236 |
else:
|
237 |
torch.save(state_dict, file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import math
|
7 |
import os
|
8 |
from typing import List
|
9 |
+
import numpy as np
|
10 |
import torch
|
11 |
|
12 |
from library import train_util
|
|
|
21 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
22 |
super().__init__()
|
23 |
self.lora_name = lora_name
|
|
|
24 |
|
25 |
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
in_dim = org_module.in_channels
|
27 |
out_dim = org_module.out_channels
|
|
|
|
|
28 |
else:
|
29 |
in_dim = org_module.in_features
|
30 |
out_dim = org_module.out_features
|
31 |
+
|
32 |
+
# if limit_rank:
|
33 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
34 |
+
# if self.lora_dim != lora_dim:
|
35 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
36 |
+
# else:
|
37 |
+
self.lora_dim = lora_dim
|
38 |
+
|
39 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
40 |
+
kernel_size = org_module.kernel_size
|
41 |
+
stride = org_module.stride
|
42 |
+
padding = org_module.padding
|
43 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
44 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
45 |
+
else:
|
46 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
47 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
48 |
|
49 |
if type(alpha) == torch.Tensor:
|
50 |
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
51 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
52 |
self.scale = alpha / self.lora_dim
|
53 |
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
54 |
|
|
|
58 |
|
59 |
self.multiplier = multiplier
|
60 |
self.org_module = org_module # remove in applying
|
61 |
+
self.region = None
|
62 |
+
self.region_mask = None
|
63 |
|
64 |
def apply_to(self):
|
65 |
self.org_forward = self.org_module.forward
|
66 |
self.org_module.forward = self.forward
|
67 |
del self.org_module
|
68 |
|
69 |
+
def set_region(self, region):
|
70 |
+
self.region = region
|
71 |
+
self.region_mask = None
|
72 |
+
|
73 |
def forward(self, x):
|
74 |
+
if self.region is None:
|
75 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
76 |
+
|
77 |
+
# regional LoRA FIXME same as additional-network extension
|
78 |
+
if x.size()[1] % 77 == 0:
|
79 |
+
# print(f"LoRA for context: {self.lora_name}")
|
80 |
+
self.region = None
|
81 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
82 |
+
|
83 |
+
# calculate region mask first time
|
84 |
+
if self.region_mask is None:
|
85 |
+
if len(x.size()) == 4:
|
86 |
+
h, w = x.size()[2:4]
|
87 |
+
else:
|
88 |
+
seq_len = x.size()[1]
|
89 |
+
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
90 |
+
h = int(self.region.size()[0] / ratio + .5)
|
91 |
+
w = seq_len // h
|
92 |
+
|
93 |
+
r = self.region.to(x.device)
|
94 |
+
if r.dtype == torch.bfloat16:
|
95 |
+
r = r.to(torch.float)
|
96 |
+
r = r.unsqueeze(0).unsqueeze(1)
|
97 |
+
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
98 |
+
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
|
99 |
+
r = r.to(x.dtype)
|
100 |
+
|
101 |
+
if len(x.size()) == 3:
|
102 |
+
r = torch.reshape(r, (1, x.size()[1], -1))
|
103 |
+
|
104 |
+
self.region_mask = r
|
105 |
+
|
106 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
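A minimal sketch of the mask handling above for the 3D (batch, sequence, channels) case: the region image is resized to the latent grid implied by the sequence length and broadcast over the channel dimension. All tensors are random stand-ins for the real region image and activation:

import math
import torch
import torch.nn.functional as F

region = torch.rand(64, 48)       # region mask in [0, 1]
x = torch.randn(2, 192, 320)      # (batch, seq_len, channels) activation

seq_len = x.size(1)
ratio = math.sqrt(region.numel() / seq_len)   # downscale factor from mask to latent grid
h = int(region.size(0) / ratio + 0.5)
w = seq_len // h

mask = F.interpolate(region[None, None], (h, w), mode='bilinear')  # (1, 1, h, w)
mask = mask.reshape(1, seq_len, 1)                                  # broadcast over channels

lora_delta = torch.randn_like(x)  # stand-in for lora_up(lora_down(x)) * multiplier * scale
x_out = x + lora_delta * mask
print(x_out.shape)                # torch.Size([2, 192, 320])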
|
107 |
|
108 |
|
109 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
110 |
if network_dim is None:
|
111 |
network_dim = 4 # default
|
|
|
|
|
112 |
|
113 |
+
# extract dim/alpha for conv2d, and block dim
|
114 |
+
conv_dim = kwargs.get('conv_dim', None)
|
115 |
+
conv_alpha = kwargs.get('conv_alpha', None)
|
116 |
+
if conv_dim is not None:
|
117 |
+
conv_dim = int(conv_dim)
|
118 |
+
if conv_alpha is None:
|
119 |
+
conv_alpha = 1.0
|
120 |
+
else:
|
121 |
+
conv_alpha = float(conv_alpha)
|
122 |
|
123 |
+
"""
|
124 |
+
block_dims = kwargs.get("block_dims")
|
125 |
+
block_alphas = None
|
126 |
+
|
127 |
+
if block_dims is not None:
|
128 |
+
block_dims = [int(d) for d in block_dims.split(',')]
|
129 |
+
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
130 |
+
block_alphas = kwargs.get("block_alphas")
|
131 |
+
if block_alphas is None:
|
132 |
+
block_alphas = [1] * len(block_dims)
|
133 |
+
else:
|
134 |
+
block_alphas = [int(a) for a in block_alphas.split(',')]
|
135 |
+
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
136 |
+
|
137 |
+
conv_block_dims = kwargs.get("conv_block_dims")
|
138 |
+
conv_block_alphas = None
|
139 |
+
|
140 |
+
if conv_block_dims is not None:
|
141 |
+
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
142 |
+
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
143 |
+
conv_block_alphas = kwargs.get("conv_block_alphas")
|
144 |
+
if conv_block_alphas is None:
|
145 |
+
conv_block_alphas = [1] * len(conv_block_dims)
|
146 |
+
else:
|
147 |
+
conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
|
148 |
+
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
149 |
+
"""
|
150 |
|
151 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
|
152 |
+
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
|
153 |
+
return network
|
|
|
|
|
|
|
|
|
|
|
154 |
|
|
|
|
|
155 |
|
156 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
157 |
+
if weights_sd is None:
|
158 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
159 |
+
from safetensors.torch import load_file, safe_open
|
160 |
+
weights_sd = load_file(file)
|
161 |
+
else:
|
162 |
+
weights_sd = torch.load(file, map_location='cpu')
|
163 |
+
|
164 |
+
# get dim/alpha mapping
|
165 |
+
modules_dim = {}
|
166 |
+
modules_alpha = {}
|
167 |
+
for key, value in weights_sd.items():
|
168 |
+
if '.' not in key:
|
169 |
+
continue
|
170 |
+
|
171 |
+
lora_name = key.split('.')[0]
|
172 |
+
if 'alpha' in key:
|
173 |
+
modules_alpha[lora_name] = value
|
174 |
+
elif 'lora_down' in key:
|
175 |
+
dim = value.size()[0]
|
176 |
+
modules_dim[lora_name] = dim
|
177 |
+
# print(lora_name, value.size(), dim)
|
178 |
+
|
179 |
+
# support old LoRA without alpha
|
180 |
+
for key in modules_dim.keys():
|
181 |
+
if key not in modules_alpha:
|
182 |
+
modules_alpha[key] = modules_dim[key]
|
183 |
+
|
184 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
185 |
network.weights_sd = weights_sd
|
186 |
return network
|
187 |
|
188 |
|
189 |
class LoRANetwork(torch.nn.Module):
|
190 |
+
# is it possible to apply conv_in and conv_out?
|
191 |
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
192 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
193 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
194 |
LORA_PREFIX_UNET = 'lora_unet'
|
195 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
196 |
|
197 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
|
198 |
super().__init__()
|
199 |
self.multiplier = multiplier
|
200 |
+
|
201 |
self.lora_dim = lora_dim
|
202 |
self.alpha = alpha
|
203 |
+
self.conv_lora_dim = conv_lora_dim
|
204 |
+
self.conv_alpha = conv_alpha
|
205 |
+
|
206 |
+
if modules_dim is not None:
|
207 |
+
print(f"create LoRA network from weights")
|
208 |
+
else:
|
209 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
210 |
+
|
211 |
+
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
212 |
+
if self.apply_to_conv2d_3x3:
|
213 |
+
if self.conv_alpha is None:
|
214 |
+
self.conv_alpha = self.alpha
|
215 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
216 |
|
217 |
# create module instances
|
218 |
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
219 |
loras = []
|
220 |
for name, module in root_module.named_modules():
|
221 |
if module.__class__.__name__ in target_replace_modules:
|
222 |
+
# TODO get block index here
|
223 |
for child_name, child_module in module.named_modules():
|
224 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
225 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
226 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
227 |
+
if is_linear or is_conv2d:
|
228 |
lora_name = prefix + '.' + name + '.' + child_name
|
229 |
lora_name = lora_name.replace('.', '_')
|
230 |
+
|
231 |
+
if modules_dim is not None:
|
232 |
+
if lora_name not in modules_dim:
|
233 |
+
continue # no LoRA module in this weights file
|
234 |
+
dim = modules_dim[lora_name]
|
235 |
+
alpha = modules_alpha[lora_name]
|
236 |
+
else:
|
237 |
+
if is_linear or is_conv2d_1x1:
|
238 |
+
dim = self.lora_dim
|
239 |
+
alpha = self.alpha
|
240 |
+
elif self.apply_to_conv2d_3x3:
|
241 |
+
dim = self.conv_lora_dim
|
242 |
+
alpha = self.conv_alpha
|
243 |
+
else:
|
244 |
+
continue
|
245 |
+
|
246 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
247 |
loras.append(lora)
|
248 |
return loras
|
249 |
|
|
|
251 |
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
252 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
253 |
|
254 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
255 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
256 |
+
if modules_dim is not None or self.conv_lora_dim is not None:
|
257 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
258 |
+
|
259 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
260 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
261 |
|
262 |
self.weights_sd = None
|
|
|
267 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
268 |
names.add(lora.lora_name)
|
269 |
|
270 |
+
def set_multiplier(self, multiplier):
|
271 |
+
self.multiplier = multiplier
|
272 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
273 |
+
lora.multiplier = self.multiplier
|
274 |
+
|
275 |
def load_weights(self, file):
|
276 |
if os.path.splitext(file)[1] == '.safetensors':
|
277 |
from safetensors.torch import load_file, safe_open
|
|
|
381 |
save_file(state_dict, file, metadata)
|
382 |
else:
|
383 |
torch.save(state_dict, file)
|
384 |
+
|
385 |
+
@ staticmethod
|
386 |
+
def set_regions(networks, image):
|
387 |
+
image = image.astype(np.float32) / 255.0
|
388 |
+
for i, network in enumerate(networks[:3]):
|
389 |
+
# NOTE: consider averaging overlapping areas
|
390 |
+
region = image[:, :, i]
|
391 |
+
if region.max() == 0:
|
392 |
+
continue
|
393 |
+
region = torch.tensor(region)
|
394 |
+
network.set_region(region)
|
395 |
+
|
396 |
+
def set_region(self, region):
|
397 |
+
for lora in self.unet_loras:
|
398 |
+
lora.set_region(region)
|
networks/merge_lora.py
CHANGED
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|
48 |
for name, module in root_module.named_modules():
|
49 |
if module.__class__.__name__ in target_replace_modules:
|
50 |
for child_name, child_module in module.named_modules():
|
51 |
-
if child_module.__class__.__name__ == "Linear" or
|
52 |
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
lora_name = lora_name.replace('.', '_')
|
54 |
name_to_module[lora_name] = child_module
|
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|
80 |
|
81 |
# W <- W + U * D
|
82 |
weight = module.weight
|
|
|
83 |
if len(weight.size()) == 2:
|
84 |
# linear
|
85 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
86 |
-
|
87 |
-
# conv2d
|
88 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
89 |
).unsqueeze(2).unsqueeze(3) * scale
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
module.weight = torch.nn.Parameter(weight)
|
92 |
|
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|
123 |
alphas[lora_module_name] = alpha
|
124 |
if lora_module_name not in base_alphas:
|
125 |
base_alphas[lora_module_name] = alpha
|
126 |
-
|
127 |
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
128 |
|
129 |
# merge
|
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|
145 |
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
146 |
else:
|
147 |
merged_sd[key] = lora_sd[key] * scale
|
148 |
-
|
149 |
# set alpha to sd
|
150 |
for lora_module_name, alpha in base_alphas.items():
|
151 |
key = lora_module_name + ".alpha"
|
|
|
48 |
for name, module in root_module.named_modules():
|
49 |
if module.__class__.__name__ in target_replace_modules:
|
50 |
for child_name, child_module in module.named_modules():
|
51 |
+
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
52 |
lora_name = prefix + '.' + name + '.' + child_name
|
53 |
lora_name = lora_name.replace('.', '_')
|
54 |
name_to_module[lora_name] = child_module
|
|
|
80 |
|
81 |
# W <- W + U * D
|
82 |
weight = module.weight
|
83 |
+
# print(module_name, down_weight.size(), up_weight.size())
|
84 |
if len(weight.size()) == 2:
|
85 |
# linear
|
86 |
weight = weight + ratio * (up_weight @ down_weight) * scale
|
87 |
+
elif down_weight.size()[2:4] == (1, 1):
|
88 |
+
# conv2d 1x1
|
89 |
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
90 |
).unsqueeze(2).unsqueeze(3) * scale
|
91 |
+
else:
|
92 |
+
# conv2d 3x3
|
93 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
94 |
+
# print(conved.size(), weight.size(), module.stride, module.padding)
|
95 |
+
weight = weight + ratio * conved * scale
|
96 |
|
97 |
module.weight = torch.nn.Parameter(weight)
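The conv2d 3x3 branch above merges the two LoRA convs into a single 3x3 kernel by convolving the down weights with the 1x1 up weights. A minimal numerical check of that identity, assuming stride 1 and padding 1 for the original layer:

import torch
import torch.nn.functional as F

out_dim, in_dim, rank = 8, 4, 2
down = torch.randn(rank, in_dim, 3, 3)   # lora_down weight
up = torch.randn(out_dim, rank, 1, 1)    # lora_up weight

# merged 3x3 kernel: combine down's kernels with up's 1x1 mixing weights
merged = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)  # (out_dim, in_dim, 3, 3)

x = torch.randn(1, in_dim, 16, 16)
y_seq = F.conv2d(F.conv2d(x, down, padding=1), up)   # lora_down then lora_up
y_merged = F.conv2d(x, merged, padding=1)            # single merged conv
print(torch.allclose(y_seq, y_merged, atol=1e-4))    # True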
|
98 |
|
|
|
129 |
alphas[lora_module_name] = alpha
|
130 |
if lora_module_name not in base_alphas:
|
131 |
base_alphas[lora_module_name] = alpha
|
132 |
+
|
133 |
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
134 |
|
135 |
# merge
|
|
|
151 |
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
152 |
else:
|
153 |
merged_sd[key] = lora_sd[key] * scale
|
154 |
+
|
155 |
# set alpha to sd
|
156 |
for lora_module_name, alpha in base_alphas.items():
|
157 |
key = lora_module_name + ".alpha"
|
networks/resize_lora.py
CHANGED
@@ -1,14 +1,15 @@
 # Convert LoRA to different rank approximation (should only be used to go to lower rank)
 # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
 # Thanks to cloneofsimo

 import argparse
-import os
 import torch
 from safetensors.torch import load_file, save_file, safe_open
 from tqdm import tqdm
 from library import train_util, model_util
+import numpy as np

+MIN_SV = 1e-6

 def load_state_dict(file_name, dtype):
   if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
     torch.save(model, file_name)


-def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
+def index_sv_cumulative(S, target):
+  original_sum = float(torch.sum(S))
+  cumulative_sums = torch.cumsum(S, dim=0)/original_sum
+  index = int(torch.searchsorted(cumulative_sums, target)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+def index_sv_fro(S, target):
+  S_squared = S.pow(2)
+  s_fro_sq = float(torch.sum(S_squared))
+  sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
+  index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+# Modified from Kohaku-blueleaf's extract/merge functions
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size, kernel_size, _ = weight.size()
+  U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size = weight.size()
+
+  U, S, Vh = torch.linalg.svd(weight.to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+  in_rank, in_size, kernel_size, k_ = lora_down.shape
+  out_size, out_rank, _, _ = lora_up.shape
+  assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+  weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+  del lora_up, lora_down
+  return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+  in_rank, in_size = lora_down.shape
+  out_size, out_rank = lora_up.shape
+  assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  weight = lora_up @ lora_down
+  del lora_up, lora_down
+  return weight
+
+
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+  param_dict = {}
+
+  if dynamic_method == "sv_ratio":
+    # Calculate new dim and alpha based off ratio
+    max_sv = S[0]
+    min_sv = max_sv/dynamic_param
+    new_rank = max(torch.sum(S > min_sv).item(), 1)
+    new_alpha = float(scale*new_rank)
+
+  elif dynamic_method == "sv_cumulative":
+    # Calculate new dim and alpha based off cumulative sum
+    new_rank = index_sv_cumulative(S, dynamic_param)
+    new_rank = max(new_rank, 1)
+    new_alpha = float(scale*new_rank)
+
+  elif dynamic_method == "sv_fro":
+    # Calculate new dim and alpha based off sqrt sum of squares
+    new_rank = index_sv_fro(S, dynamic_param)
+    new_rank = min(max(new_rank, 1), len(S)-1)
+    new_alpha = float(scale*new_rank)
+  else:
+    new_rank = rank
+    new_alpha = float(scale*new_rank)
+
+  if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
+    new_rank = 1
+    new_alpha = float(scale*new_rank)
+  elif new_rank > rank:  # cap max rank at rank
+    new_rank = rank
+    new_alpha = float(scale*new_rank)
+
+  # Calculate resize info
+  s_sum = torch.sum(torch.abs(S))
+  s_rank = torch.sum(torch.abs(S[:new_rank]))
+
+  S_squared = S.pow(2)
+  s_fro = torch.sqrt(torch.sum(S_squared))
+  s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+  fro_percent = float(s_red_fro/s_fro)
+
+  param_dict["new_rank"] = new_rank
+  param_dict["new_alpha"] = new_alpha
+  param_dict["sum_retained"] = (s_rank)/s_sum
+  param_dict["fro_retained"] = fro_percent
+  param_dict["max_ratio"] = S[0]/S[new_rank]
+
+  return param_dict
+
+
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
   network_alpha = None
   network_dim = None
   verbose_str = "\n"
-
-  CLAMP_QUANTILE = 0.99
+  fro_list = []

   # Extract loaded lora dim and alpha
   for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
     network_alpha = network_dim

   scale = network_alpha/network_dim
-  new_alpha = float(scale*new_rank)  # calculate new alpha from scale

+  if dynamic_method:
+    print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

   lora_down_weight = None
   lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
   block_down_name = None
   block_up_name = None

-  print("resizing lora...")
   with torch.no_grad():
     for key, value in tqdm(lora_sd.items()):
       if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
         conv2d = (len(lora_down_weight.size()) == 4)

         if conv2d:
-          lora_up_weight = lora_up_weight.to(args.device)
-          lora_down_weight = lora_down_weight.to(args.device)
-
-          full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
-
-          U, S, Vh = torch.linalg.svd(full_weight_matrix)
-
-          U = U @ torch.diag(S)
-
-          U = U.clamp(low_val, hi_val)
-          Vh = Vh.clamp(low_val, hi_val)
-
-          if conv2d:
-            U = U.unsqueeze(2).unsqueeze(3)
-            Vh = Vh.unsqueeze(2).unsqueeze(3)
-
-          if device:
-            U = U.to(org_device)
-            Vh = Vh.to(org_device)
-
-          o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
-          o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
-          o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
+          full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+        else:
+          full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

         if verbose:
+          max_ratio = param_dict['max_ratio']
+          sum_retained = param_dict['sum_retained']
+          fro_retained = param_dict['fro_retained']
+          if not np.isnan(fro_retained):
+            fro_list.append(float(fro_retained))

+          verbose_str += f"{block_down_name:75} | "
+          verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"

+        if verbose and dynamic_method:
+          verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
+        else:
+          verbose_str += f"\n"

+        new_alpha = param_dict['new_alpha']
+        o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

         block_down_name = None
         block_up_name = None
         lora_down_weight = None
         lora_up_weight = None
         weights_loaded = False
+        del param_dict

   if verbose:
     print(verbose_str)
+
+  print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
   print("resizing complete")
   return o_lora_sd, network_dim, new_alpha

@@ -151,6 +274,9 @@ def resize(args):
       return torch.bfloat16
     return None

+  if args.dynamic_method and not args.dynamic_param:
+    raise Exception("If using dynamic_method, then dynamic_param is required")
+
   merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
   save_dtype = str_to_dtype(args.save_precision)
   if save_dtype is None:
@@ -159,17 +285,23 @@ def resize(args):
   print("loading Model...")
   lora_sd, metadata = load_state_dict(args.model, merge_dtype)

-  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
+  print("Resizing Lora...")
+  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

   # update metadata
   if metadata is None:
     metadata = {}

   comment = metadata.get("ss_training_comment", "")
+
+  if not args.dynamic_method:
+    metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
+    metadata["ss_network_dim"] = str(args.new_rank)
+    metadata["ss_network_alpha"] = str(new_alpha)
+  else:
+    metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
+    metadata["ss_network_dim"] = 'Dynamic'
+    metadata["ss_network_alpha"] = 'Dynamic'

   model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
   metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ if __name__ == '__main__':
   parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
   parser.add_argument("--verbose", action="store_true",
                       help="Display verbose resizing information / rank変更時の詳細情報を出力する")
+  parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
+                      help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
+  parser.add_argument("--dynamic_param", type=float, default=None,
+                      help="Specify target for dynamic reduction")

   args = parser.parse_args()
   resize(args)
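As a quick sanity check of the dynamic rank selection introduced above, the snippet below reproduces the sv_ratio and sv_fro rules on a made-up set of singular values (values chosen only for illustration):

import torch

S = torch.tensor([10.0, 5.0, 1.0, 0.1, 0.01])  # hypothetical singular values

# sv_ratio with dynamic_param=20: keep singular values larger than S[0]/20
print(max(int(torch.sum(S > S[0] / 20)), 1))         # -> 3

# sv_fro with dynamic_param=0.99: smallest rank keeping 99% of the Frobenius norm
S_sq = S.pow(2)
cum = torch.cumsum(S_sq, dim=0) / S_sq.sum()
print(int(torch.searchsorted(cum, 0.99 ** 2)) + 1)   # -> 2

In rank_resize the result is further clamped to the range [1, new_rank], and new_alpha is set to scale * new_rank so that the alpha/dim ratio of the original network is preserved.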
networks/svd_merge_lora.py
CHANGED
@@ -23,19 +23,20 @@ def load_state_dict(file_name, dtype):
   return sd


-def save_to_file(file_name,
+def save_to_file(file_name, state_dict, dtype):
   if dtype is not None:
     for key in list(state_dict.keys()):
       if type(state_dict[key]) == torch.Tensor:
         state_dict[key] = state_dict[key].to(dtype)

   if os.path.splitext(file_name)[1] == '.safetensors':
-    save_file(
+    save_file(state_dict, file_name)
   else:
-    torch.save(
+    torch.save(state_dict, file_name)


-def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
+def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
+  print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
   merged_sd = {}
   for model, ratio in zip(models, ratios):
     print(f"loading: {model}")
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
       in_dim = down_weight.size()[1]
       out_dim = up_weight.size()[0]
       conv2d = len(down_weight.size()) == 4
+      kernel_size = None if not conv2d else down_weight.size()[2:4]
+      # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)

       # make original weight if not exist
       if lora_module_name not in merged_sd:
-        weight = torch.zeros((out_dim, in_dim,
+        weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
         if device:
           weight = weight.to(device)
       else:
@@ -75,11 +77,18 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):

       # W <- W + U * D
       scale = (alpha / network_dim)
+
+      if device:  # and isinstance(scale, torch.Tensor):
+        scale = scale.to(device)
+
       if not conv2d:  # linear
         weight = weight + ratio * (up_weight @ down_weight) * scale
-      else:
+      elif kernel_size == (1, 1):
         weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                    ).unsqueeze(2).unsqueeze(3) * scale
+      else:
+        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+        weight = weight + ratio * conved * scale

       merged_sd[lora_module_name] = weight

@@ -89,16 +98,26 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
   with torch.no_grad():
     for lora_module_name, mat in tqdm(list(merged_sd.items())):
       conv2d = (len(mat.size()) == 4)
+      kernel_size = None if not conv2d else mat.size()[2:4]
+      conv2d_3x3 = conv2d and kernel_size != (1, 1)
+      out_dim, in_dim = mat.size()[0:2]
+
       if conv2d:
-        mat = mat.squeeze()
+        if conv2d_3x3:
+          mat = mat.flatten(start_dim=1)
+        else:
+          mat = mat.squeeze()
+
+      module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+      module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

       U, S, Vh = torch.linalg.svd(mat)

-      U = U[:, :new_rank]
-      S = S[:new_rank]
+      U = U[:, :module_new_rank]
+      S = S[:module_new_rank]
       U = U @ torch.diag(S)

-      Vh = Vh[:new_rank, :]
+      Vh = Vh[:module_new_rank, :]

       dist = torch.cat([U.flatten(), Vh.flatten()])
       hi_val = torch.quantile(dist, CLAMP_QUANTILE)
@@ -107,16 +126,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
       U = U.clamp(low_val, hi_val)
       Vh = Vh.clamp(low_val, hi_val)

+      if conv2d:
+        U = U.reshape(out_dim, module_new_rank, 1, 1)
+        Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
+
       up_weight = U
       down_weight = Vh

-      if conv2d:
-        up_weight = up_weight.unsqueeze(2).unsqueeze(3)
-        down_weight = down_weight.unsqueeze(2).unsqueeze(3)
-
       merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
       merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
-      merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(
+      merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)

   return merged_lora_sd

@@ -138,10 +157,11 @@ def merge(args):
   if save_dtype is None:
     save_dtype = merge_dtype

-  state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
+  new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
+  state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)

   print(f"saving model to: {args.save_to}")
-  save_to_file(args.save_to, state_dict,
+  save_to_file(args.save_to, state_dict, save_dtype)


 if __name__ == '__main__':
@@ -158,6 +178,8 @@ if __name__ == '__main__':
                       help="ratios for each model / それぞれのLoRAモデルの比率")
   parser.add_argument("--new_rank", type=int, default=4,
                       help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
+  parser.add_argument("--new_conv_rank", type=int, default=None,
+                      help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
   parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")

   args = parser.parse_args()
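The new conv2d 3x3 path above works because an (out, in, 3, 3) weight can be flattened to an (out, in*9) matrix, factored by SVD, and reshaped back into a 1x1 up kernel and a 3x3 down kernel. A small shape check of that round trip (dimensions are illustrative only):

import torch

out_dim, in_dim, k, rank = 16, 8, 3, 4
mat = torch.randn(out_dim, in_dim, k, k)      # merged delta weight for a 3x3 conv

U, S, Vh = torch.linalg.svd(mat.flatten(start_dim=1))
U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
U = U @ torch.diag(S)

lora_up = U.reshape(out_dim, rank, 1, 1)      # same reshapes as merge_lora_models
lora_down = Vh.reshape(rank, in_dim, k, k)
print(lora_up.shape, lora_down.shape)         # (16, 4, 1, 1) (4, 8, 3, 3)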
requirements.txt
CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
 gradio==3.16.2
 altair==4.2.2
 easygui==0.98.3
+toml==0.10.2
+voluptuous==0.13.1
 # for BLIP captioning
 requests==2.28.2
 timm==0.6.12
@@ -21,5 +23,4 @@ fairscale==0.4.13
 tensorflow==2.10.1
 huggingface-hub==0.12.0
 # for kohya_ss library
-#locon.locon_kohya
 .
train_db.py
CHANGED
@@ -15,7 +15,11 @@ import diffusers
 from diffusers import DDPMScheduler

 import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+  ConfigSanitizer,
+  BlueprintGenerator,
+)


 def collate_fn(examples):
@@ -33,24 +37,33 @@ def train(args):

   tokenizer = train_util.load_tokenizer(args)

+  blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+  if args.dataset_config is not None:
+    print(f"Load dataset config from {args.dataset_config}")
+    user_config = config_util.load_user_config(args.dataset_config)
+    ignored = ["train_data_dir", "reg_data_dir"]
+    if any(getattr(args, attr) is not None for attr in ignored):
+      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+  else:
+    user_config = {
+      "datasets": [{
+        "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+      }]
+    }

+  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

+  if args.no_token_padding:
+    train_dataset_group.disable_token_padding()

   if args.debug_dataset:
-    train_util.debug_dataset(
+    train_util.debug_dataset(train_dataset_group)
     return

+  if cache_latents:
+    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
   # acceleratorを準備する
   print("prepare accelerator")

@@ -91,7 +104,7 @@ def train(args):
   vae.requires_grad_(False)
   vae.eval()
   with torch.no_grad():
+    train_dataset_group.cache_latents(vae)
   vae.to("cpu")
   if torch.cuda.is_available():
     torch.cuda.empty_cache()
@@ -115,38 +128,18 @@ def train(args):

   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-
-  # 8-bit Adamを使う
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
   if train_text_encoder:
     trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
   else:
     trainable_params = unet.parameters()

-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-
+    train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -156,9 +149,10 @@ def train(args):
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

-  # lr schedulerを用意する
-  lr_scheduler =
+  # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -195,8 +189,8 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
-  print(f" num reg images / 正則化画像の数: {
+  print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+  print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
   print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f" num epochs / epoch数: {num_train_epochs}")
   print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -217,7 +211,7 @@ def train(args):
   loss_total = 0.0
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset_group.set_current_epoch(epoch + 1)

     # 指定したステップ数までText Encoderを学習する:epoch最初の状態
     unet.train()
@@ -281,12 +275,12 @@ def train(args):
         loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           if train_text_encoder:
             params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
           else:
             params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip,
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -297,9 +291,13 @@ def train(args):
           progress_bar.update(1)
           global_step += 1

+          train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+            logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)

         if epoch == 0:
@@ -326,6 +324,8 @@ def train(args):
       train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                             save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)

+    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
   is_main_process = accelerator.is_main_process
   if is_main_process:
     unet = unwrap_model(unet)
@@ -352,6 +352,8 @@ if __name__ == '__main__':
   train_util.add_dataset_arguments(parser, True, False, True)
   train_util.add_training_arguments(parser, True)
   train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)
+  config_util.add_config_arguments(parser)

   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
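train_db.py now defers dataset construction to library/config_util.py: without --dataset_config it builds the fallback user_config dict shown above, otherwise it loads the config file (TOML support comes from the toml package added to requirements.txt). As a rough illustration only, here is a hand-written equivalent of such a config in Python terms; the path is hypothetical and the authoritative schema is ConfigSanitizer in library/config_util.py, which is not part of this excerpt:

import toml  # toml==0.10.2, added to requirements.txt in this commit

# illustrative stand-in for what a dataset config describes; validation and
# defaulting are handled by config_util, not by this dict itself
user_config = {
    "datasets": [{
        "subsets": [{
            "image_dir": "/path/to/train_images",  # hypothetical path
        }]
    }]
}

# the same structure can be stored as TOML and read back with toml.load()
print(toml.dumps(user_config))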
train_network.py
CHANGED
@@ -1,8 +1,4 @@
|
|
1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
-
from torch.optim import Optimizer
|
3 |
-
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
-
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
@@ -15,92 +11,39 @@ import json
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
-
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
|
21 |
import library.train_util as train_util
|
22 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
26 |
return examples[0]
|
27 |
|
28 |
|
|
|
29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
31 |
|
32 |
if args.network_train_unet_only:
|
33 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
34 |
elif args.network_train_text_encoder_only:
|
35 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
36 |
else:
|
37 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
38 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
39 |
-
|
40 |
-
return logs
|
41 |
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
45 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
46 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
47 |
-
|
48 |
-
|
49 |
-
def get_scheduler_fix(
|
50 |
-
name: Union[str, SchedulerType],
|
51 |
-
optimizer: Optimizer,
|
52 |
-
num_warmup_steps: Optional[int] = None,
|
53 |
-
num_training_steps: Optional[int] = None,
|
54 |
-
num_cycles: int = 1,
|
55 |
-
power: float = 1.0,
|
56 |
-
):
|
57 |
-
"""
|
58 |
-
Unified API to get any scheduler from its name.
|
59 |
-
Args:
|
60 |
-
name (`str` or `SchedulerType`):
|
61 |
-
The name of the scheduler to use.
|
62 |
-
optimizer (`torch.optim.Optimizer`):
|
63 |
-
The optimizer that will be used during training.
|
64 |
-
num_warmup_steps (`int`, *optional*):
|
65 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
66 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
67 |
-
num_training_steps (`int``, *optional*):
|
68 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
69 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
70 |
-
num_cycles (`int`, *optional*):
|
71 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
72 |
-
power (`float`, *optional*, defaults to 1.0):
|
73 |
-
Power factor. See `POLYNOMIAL` scheduler
|
74 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
75 |
-
The index of the last epoch when resuming training.
|
76 |
-
"""
|
77 |
-
name = SchedulerType(name)
|
78 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
79 |
-
if name == SchedulerType.CONSTANT:
|
80 |
-
return schedule_func(optimizer)
|
81 |
-
|
82 |
-
# All other schedulers require `num_warmup_steps`
|
83 |
-
if num_warmup_steps is None:
|
84 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
85 |
-
|
86 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
87 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
88 |
-
|
89 |
-
# All other schedulers require `num_training_steps`
|
90 |
-
if num_training_steps is None:
|
91 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
92 |
-
|
93 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
94 |
-
return schedule_func(
|
95 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
96 |
-
)
|
97 |
-
|
98 |
-
if name == SchedulerType.POLYNOMIAL:
|
99 |
-
return schedule_func(
|
100 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
101 |
-
)
|
102 |
-
|
103 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
104 |
|
105 |
|
106 |
def train(args):
|
@@ -111,6 +54,7 @@ def train(args):
|
|
111 |
|
112 |
cache_latents = args.cache_latents
|
113 |
use_dreambooth_method = args.in_json is None
|
|
|
114 |
|
115 |
if args.seed is not None:
|
116 |
set_seed(args.seed)
|
@@ -118,38 +62,51 @@ def train(args):
|
|
118 |
tokenizer = train_util.load_tokenizer(args)
|
119 |
|
120 |
# データセットを準備する
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
else:
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if args.debug_dataset:
|
144 |
-
train_util.debug_dataset(
|
145 |
return
|
146 |
-
if len(
|
147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# acceleratorを準備する
|
151 |
print("prepare accelerator")
|
152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
153 |
|
154 |
# mixed precisionに対応した型を用意しておき適宜castする
|
155 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
@@ -161,7 +118,7 @@ def train(args):
|
|
161 |
if args.lowram:
|
162 |
text_encoder.to("cuda")
|
163 |
unet.to("cuda")
|
164 |
-
|
165 |
# モデルに xformers とか memory efficient attention を組み込む
|
166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
167 |
|
@@ -171,13 +128,15 @@ def train(args):
|
|
171 |
vae.requires_grad_(False)
|
172 |
vae.eval()
|
173 |
with torch.no_grad():
|
174 |
-
|
175 |
vae.to("cpu")
|
176 |
if torch.cuda.is_available():
|
177 |
torch.cuda.empty_cache()
|
178 |
gc.collect()
|
179 |
|
180 |
# prepare network
|
|
|
|
|
181 |
print("import network module:", args.network_module)
|
182 |
network_module = importlib.import_module(args.network_module)
|
183 |
|
@@ -208,48 +167,25 @@ def train(args):
|
|
208 |
# 学習に必要なクラスを準備する
|
209 |
print("prepare optimizer, data loader etc.")
|
210 |
|
211 |
-
# 8-bit Adamを使う
|
212 |
-
if args.use_8bit_adam:
|
213 |
-
try:
|
214 |
-
import bitsandbytes as bnb
|
215 |
-
except ImportError:
|
216 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
217 |
-
print("use 8-bit Adam optimizer")
|
218 |
-
optimizer_class = bnb.optim.AdamW8bit
|
219 |
-
elif args.use_lion_optimizer:
|
220 |
-
try:
|
221 |
-
import lion_pytorch
|
222 |
-
except ImportError:
|
223 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
224 |
-
print("use Lion optimizer")
|
225 |
-
optimizer_class = lion_pytorch.Lion
|
226 |
-
else:
|
227 |
-
optimizer_class = torch.optim.AdamW
|
228 |
-
|
229 |
-
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
230 |
-
|
231 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
232 |
-
|
233 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
234 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
235 |
|
236 |
# dataloaderを準備する
|
237 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
238 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
239 |
train_dataloader = torch.utils.data.DataLoader(
|
240 |
-
|
241 |
|
242 |
# 学習ステップ数を計算する
|
243 |
if args.max_train_epochs is not None:
|
244 |
-
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
245 |
-
|
|
|
246 |
|
247 |
# lr schedulerを用意する
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
252 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
253 |
|
254 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
255 |
if args.full_fp16:
|
@@ -317,17 +253,21 @@ def train(args):
|
|
317 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
318 |
|
319 |
# 学習する
|
|
|
320 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
331 |
metadata = {
|
332 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
333 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -335,12 +275,10 @@ def train(args):
|
|
335 |
"ss_learning_rate": args.learning_rate,
|
336 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
337 |
"ss_unet_lr": args.unet_lr,
|
338 |
-
"ss_num_train_images":
|
339 |
-
"ss_num_reg_images":
|
340 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
341 |
"ss_num_epochs": num_train_epochs,
|
342 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
343 |
-
"ss_total_batch_size": total_batch_size,
|
344 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
345 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
346 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -352,33 +290,156 @@ def train(args):
|
|
352 |
"ss_mixed_precision": args.mixed_precision,
|
353 |
"ss_full_fp16": bool(args.full_fp16),
|
354 |
"ss_v2": bool(args.v2),
|
355 |
-
"ss_resolution": args.resolution,
|
356 |
"ss_clip_skip": args.clip_skip,
|
357 |
"ss_max_token_length": args.max_token_length,
|
358 |
-
"ss_color_aug": bool(args.color_aug),
|
359 |
-
"ss_flip_aug": bool(args.flip_aug),
|
360 |
-
"ss_random_crop": bool(args.random_crop),
|
361 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
362 |
"ss_cache_latents": bool(args.cache_latents),
|
363 |
-
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
364 |
-
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
365 |
-
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
366 |
"ss_seed": args.seed,
|
367 |
-
"
|
368 |
"ss_noise_offset": args.noise_offset,
|
369 |
-
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
370 |
-
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
371 |
-
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
372 |
-
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
373 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
374 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
375 |
-
"ss_optimizer": optimizer_name
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
}
|
377 |
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
if args.pretrained_model_name_or_path is not None:
|
383 |
sd_model_name = args.pretrained_model_name_or_path
|
384 |
if os.path.exists(sd_model_name):
|
@@ -397,6 +458,13 @@ def train(args):
|
|
397 |
|
398 |
metadata = {k: str(v) for k, v in metadata.items()}
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
401 |
global_step = 0
|
402 |
|
@@ -409,8 +477,9 @@ def train(args):
|
|
409 |
loss_list = []
|
410 |
loss_total = 0.0
|
411 |
for epoch in range(num_train_epochs):
|
412 |
-
|
413 |
-
|
|
|
414 |
|
415 |
metadata["ss_epoch"] = str(epoch+1)
|
416 |
|
@@ -447,7 +516,7 @@ def train(args):
|
|
447 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
448 |
|
449 |
# Predict the noise residual
|
450 |
-
with autocast():
|
451 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
452 |
|
453 |
if args.v_parameterization:
|
@@ -465,9 +534,9 @@ def train(args):
|
|
465 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
466 |
|
467 |
accelerator.backward(loss)
|
468 |
-
if accelerator.sync_gradients:
|
469 |
params_to_clip = network.get_trainable_params()
|
470 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
471 |
|
472 |
optimizer.step()
|
473 |
lr_scheduler.step()
|
@@ -478,6 +547,8 @@ def train(args):
|
|
478 |
progress_bar.update(1)
|
479 |
global_step += 1
|
480 |
|
|
|
|
|
481 |
current_loss = loss.detach().item()
|
482 |
if epoch == 0:
|
483 |
loss_list.append(current_loss)
|
@@ -508,8 +579,9 @@ def train(args):
|
|
508 |
def save_func():
|
509 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
510 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
511 |
print(f"saving checkpoint: {ckpt_file}")
|
512 |
-
unwrap_model(network).save_weights(ckpt_file, save_dtype,
|
513 |
|
514 |
def remove_old_func(old_epoch_no):
|
515 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
@@ -518,15 +590,18 @@ def train(args):
|
|
518 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
519 |
os.remove(old_ckpt_file)
|
520 |
|
521 |
-
|
522 |
-
|
523 |
-
|
|
|
|
|
|
|
524 |
|
525 |
# end of epoch
|
526 |
|
527 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
528 |
|
529 |
-
is_main_process = accelerator.is_main_process
|
530 |
if is_main_process:
|
531 |
network = unwrap_model(network)
|
532 |
|
@@ -545,7 +620,7 @@ def train(args):
|
|
545 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
546 |
|
547 |
print(f"save trained model to {ckpt_file}")
|
548 |
-
network.save_weights(ckpt_file, save_dtype,
|
549 |
print("model saved.")
|
550 |
|
551 |
|
@@ -555,6 +630,8 @@ if __name__ == '__main__':
|
|
555 |
train_util.add_sd_models_arguments(parser)
|
556 |
train_util.add_dataset_arguments(parser, True, True, True)
|
557 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
558 |
|
559 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
560 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -562,10 +639,6 @@ if __name__ == '__main__':
|
|
562 |
|
563 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
564 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
565 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
566 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
567 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
568 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
569 |
|
570 |
parser.add_argument("--network_weights", type=str, default=None,
|
571 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
|
|
|
|
|
|
1 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
2 |
import importlib
|
3 |
import argparse
|
4 |
import gc
|
|
|
11 |
from tqdm import tqdm
|
12 |
import torch
|
13 |
from accelerate.utils import set_seed
|
|
|
14 |
from diffusers import DDPMScheduler
|
15 |
|
16 |
import library.train_util as train_util
|
17 |
+
from library.train_util import (
|
18 |
+
DreamBoothDataset,
|
19 |
+
)
|
20 |
+
import library.config_util as config_util
|
21 |
+
from library.config_util import (
|
22 |
+
ConfigSanitizer,
|
23 |
+
BlueprintGenerator,
|
24 |
+
)
|
25 |
|
26 |
|
27 |
def collate_fn(examples):
|
28 |
return examples[0]
|
29 |
|
30 |
|
31 |
+
# TODO 他のスクリプトと共通化する
|
32 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
33 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
34 |
|
35 |
if args.network_train_unet_only:
|
36 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
37 |
elif args.network_train_text_encoder_only:
|
38 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
39 |
else:
|
40 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
41 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
|
|
|
|
42 |
|
43 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
44 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
45 |
|
46 |
+
return logs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
def train(args):
|
|
|
54 |
|
55 |
cache_latents = args.cache_latents
|
56 |
use_dreambooth_method = args.in_json is None
|
57 |
+
use_user_config = args.dataset_config is not None
|
58 |
|
59 |
if args.seed is not None:
|
60 |
set_seed(args.seed)
|
|
|
62 |
tokenizer = train_util.load_tokenizer(args)
|
63 |
|
64 |
# データセットを準備する
|
65 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
66 |
+
if use_user_config:
|
67 |
+
print(f"Load dataset config from {args.dataset_config}")
|
68 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
69 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
70 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
71 |
+
print(
|
72 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
73 |
else:
|
74 |
+
if use_dreambooth_method:
|
75 |
+
print("Use DreamBooth method.")
|
76 |
+
user_config = {
|
77 |
+
"datasets": [{
|
78 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
79 |
+
}]
|
80 |
+
}
|
81 |
+
else:
|
82 |
+
print("Train with captions.")
|
83 |
+
user_config = {
|
84 |
+
"datasets": [{
|
85 |
+
"subsets": [{
|
86 |
+
"image_dir": args.train_data_dir,
|
87 |
+
"metadata_file": args.in_json,
|
88 |
+
}]
|
89 |
+
}]
|
90 |
+
}
|
91 |
+
|
92 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
93 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
94 |
|
95 |
if args.debug_dataset:
|
96 |
+
train_util.debug_dataset(train_dataset_group)
|
97 |
return
|
98 |
+
if len(train_dataset_group) == 0:
|
99 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
100 |
return
|
101 |
|
102 |
+
if cache_latents:
|
103 |
+
assert train_dataset_group.is_latent_cacheable(
|
104 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
105 |
+
|
106 |
# acceleratorを準備する
|
107 |
print("prepare accelerator")
|
108 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
109 |
+
is_main_process = accelerator.is_main_process
|
110 |
|
111 |
# mixed precisionに対応した型を用意しておき適宜castする
|
112 |
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
|
|
118 |
if args.lowram:
|
119 |
text_encoder.to("cuda")
|
120 |
unet.to("cuda")
|
121 |
+
|
122 |
# モデルに xformers とか memory efficient attention を組み込む
|
123 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
124 |
|
|
|
128 |
vae.requires_grad_(False)
|
129 |
vae.eval()
|
130 |
with torch.no_grad():
|
131 |
+
train_dataset_group.cache_latents(vae)
|
132 |
vae.to("cpu")
|
133 |
if torch.cuda.is_available():
|
134 |
torch.cuda.empty_cache()
|
135 |
gc.collect()
|
136 |
|
137 |
# prepare network
|
138 |
+
import sys
|
139 |
+
sys.path.append(os.path.dirname(__file__))
|
140 |
print("import network module:", args.network_module)
|
141 |
network_module = importlib.import_module(args.network_module)
|
142 |
|
|
|
167 |
# 学習に必要なクラスを準備する
|
168 |
print("prepare optimizer, data loader etc.")
|
169 |
|
|
|
|
|
|
|
|
170 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
171 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
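The new train_util.get_optimizer call above replaces the per-script 8-bit Adam / Lion branching and returns the optimizer name, its extra arguments and the instance, chosen from args.optimizer_type. A stripped-down sketch of that selection, assuming only the AdamW / AdamW8bit / Lion cases and ignoring optimizer-specific extra arguments:

    import torch

    def pick_optimizer(optimizer_type, trainable_params, lr):
        t = (optimizer_type or "AdamW").lower()
        if t == "adamw8bit":
            import bitsandbytes as bnb  # optional dependency
            cls = bnb.optim.AdamW8bit
        elif t == "lion":
            import lion_pytorch  # optional dependency
            cls = lion_pytorch.Lion
        else:
            cls = torch.optim.AdamW
        return cls.__name__, "", cls(trainable_params, lr=lr)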
|
|
|
|
|
172 |
|
173 |
# dataloaderを準備する
|
174 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
175 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
176 |
train_dataloader = torch.utils.data.DataLoader(
|
177 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
178 |
|
179 |
# 学習ステップ数を計算する
|
180 |
if args.max_train_epochs is not None:
|
181 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
|
182 |
+
if is_main_process:
|
183 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
187 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
|
188 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
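Note that the scheduler is handed a step budget scaled by the process count and the gradient accumulation steps, presumably because the prepared scheduler is stepped by every process and on every accumulation micro-step. With made-up numbers:

    max_train_steps = 1600               # optimizer steps requested
    num_processes = 2                    # GPUs
    gradient_accumulation_steps = 4
    num_training_steps = max_train_steps * num_processes * gradient_accumulation_steps  # 12800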
|
|
|
|
|
189 |
|
190 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
191 |
if args.full_fp16:
|
|
|
253 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
254 |
|
255 |
# 学習する
|
256 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
257 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
258 |
+
|
259 |
+
if is_main_process:
|
260 |
+
print("running training / 学習開始")
|
261 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
262 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
263 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
264 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
265 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
266 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
267 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
268 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
269 |
+
|
270 |
+
# TODO refactor metadata creation and move to util
|
271 |
metadata = {
|
272 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
273 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
275 |
"ss_learning_rate": args.learning_rate,
|
276 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
277 |
"ss_unet_lr": args.unet_lr,
|
278 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
279 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
280 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
281 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
282 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
283 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
284 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
290 |
"ss_mixed_precision": args.mixed_precision,
|
291 |
"ss_full_fp16": bool(args.full_fp16),
|
292 |
"ss_v2": bool(args.v2),
|
|
|
293 |
"ss_clip_skip": args.clip_skip,
|
294 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
295 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
296 |
"ss_seed": args.seed,
|
297 |
+
"ss_lowram": args.lowram,
|
298 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
299 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
300 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
301 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
302 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
303 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
304 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
305 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
306 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
307 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
308 |
}
|
309 |
|
310 |
+
if use_user_config:
|
311 |
+
# save metadata of multiple datasets
|
312 |
+
# NOTE: pack "ss_datasets" value as json one time
|
313 |
+
# or should also pack nested collections as json?
|
314 |
+
datasets_metadata = []
|
315 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
316 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
317 |
+
|
318 |
+
for dataset in train_dataset_group.datasets:
|
319 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
320 |
+
dataset_metadata = {
|
321 |
+
"is_dreambooth": is_dreambooth_dataset,
|
322 |
+
"batch_size_per_device": dataset.batch_size,
|
323 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
324 |
+
"num_reg_images": dataset.num_reg_images,
|
325 |
+
"resolution": (dataset.width, dataset.height),
|
326 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
327 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
328 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
329 |
+
"tag_frequency": dataset.tag_frequency,
|
330 |
+
"bucket_info": dataset.bucket_info,
|
331 |
+
}
|
332 |
+
|
333 |
+
subsets_metadata = []
|
334 |
+
for subset in dataset.subsets:
|
335 |
+
subset_metadata = {
|
336 |
+
"img_count": subset.img_count,
|
337 |
+
"num_repeats": subset.num_repeats,
|
338 |
+
"color_aug": bool(subset.color_aug),
|
339 |
+
"flip_aug": bool(subset.flip_aug),
|
340 |
+
"random_crop": bool(subset.random_crop),
|
341 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
342 |
+
"keep_tokens": subset.keep_tokens,
|
343 |
+
}
|
344 |
+
|
345 |
+
image_dir_or_metadata_file = None
|
346 |
+
if subset.image_dir:
|
347 |
+
image_dir = os.path.basename(subset.image_dir)
|
348 |
+
subset_metadata["image_dir"] = image_dir
|
349 |
+
image_dir_or_metadata_file = image_dir
|
350 |
+
|
351 |
+
if is_dreambooth_dataset:
|
352 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
353 |
+
subset_metadata["is_reg"] = subset.is_reg
|
354 |
+
if subset.is_reg:
|
355 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
356 |
+
else:
|
357 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
358 |
+
subset_metadata["metadata_file"] = metadata_file
|
359 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
360 |
+
|
361 |
+
subsets_metadata.append(subset_metadata)
|
362 |
+
|
363 |
+
# merge dataset dir: not reg subset only
|
364 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
365 |
+
if image_dir_or_metadata_file is not None:
|
366 |
+
# datasets may have a certain dir multiple times
|
367 |
+
v = image_dir_or_metadata_file
|
368 |
+
i = 2
|
369 |
+
while v in dataset_dirs_info:
|
370 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
371 |
+
i += 1
|
372 |
+
image_dir_or_metadata_file = v
|
373 |
+
|
374 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
375 |
+
"n_repeats": subset.num_repeats,
|
376 |
+
"img_count": subset.img_count
|
377 |
+
}
|
378 |
+
|
379 |
+
dataset_metadata["subsets"] = subsets_metadata
|
380 |
+
datasets_metadata.append(dataset_metadata)
|
381 |
+
|
382 |
+
# merge tag frequency:
|
383 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
384 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
385 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
386 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
387 |
+
if ds_dir_name in tag_frequency:
|
388 |
+
continue
|
389 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
390 |
+
|
391 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
392 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
393 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
394 |
+
else:
|
395 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
396 |
+
assert len(
|
397 |
+
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
398 |
+
|
399 |
+
dataset = train_dataset_group.datasets[0]
|
400 |
+
|
401 |
+
dataset_dirs_info = {}
|
402 |
+
reg_dataset_dirs_info = {}
|
403 |
+
if use_dreambooth_method:
|
404 |
+
for subset in dataset.subsets:
|
405 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
406 |
+
info[os.path.basename(subset.image_dir)] = {
|
407 |
+
"n_repeats": subset.num_repeats,
|
408 |
+
"img_count": subset.img_count
|
409 |
+
}
|
410 |
+
else:
|
411 |
+
for subset in dataset.subsets:
|
412 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
413 |
+
"n_repeats": subset.num_repeats,
|
414 |
+
"img_count": subset.img_count
|
415 |
+
}
|
416 |
+
|
417 |
+
metadata.update({
|
418 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
419 |
+
"ss_total_batch_size": total_batch_size,
|
420 |
+
"ss_resolution": args.resolution,
|
421 |
+
"ss_color_aug": bool(args.color_aug),
|
422 |
+
"ss_flip_aug": bool(args.flip_aug),
|
423 |
+
"ss_random_crop": bool(args.random_crop),
|
424 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
425 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
426 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
427 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
428 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
429 |
+
"ss_keep_tokens": args.keep_tokens,
|
430 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
431 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
432 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
433 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
434 |
+
})
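In the multi-dataset branch above, the dataset_dirs_info merge has to cope with the same directory name appearing in more than one dataset, so colliding keys get a " (2)", " (3)", ... suffix. The same idea in isolation:

    def unique_key(name, existing):
        v, i = name, 2
        while v in existing:
            v = f"{name} ({i})"
            i += 1
        return v

    dirs_info = {}
    dirs_info[unique_key("imgs", dirs_info)] = {"n_repeats": 10, "img_count": 40}
    dirs_info[unique_key("imgs", dirs_info)] = {"n_repeats": 5, "img_count": 12}   # stored as "imgs (2)"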
|
435 |
+
|
436 |
+
# add extra args
|
437 |
+
if args.network_args:
|
438 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
439 |
+
# for key, value in net_kwargs.items():
|
440 |
+
# metadata["ss_arg_" + key] = value
|
441 |
+
|
442 |
+
# model name and hash
|
443 |
if args.pretrained_model_name_or_path is not None:
|
444 |
sd_model_name = args.pretrained_model_name_or_path
|
445 |
if os.path.exists(sd_model_name):
|
|
|
458 |
|
459 |
metadata = {k: str(v) for k, v in metadata.items()}
|
460 |
|
461 |
+
# make minimum metadata for filtering
|
462 |
+
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
463 |
+
minimum_metadata = {}
|
464 |
+
for key in minimum_keys:
|
465 |
+
if key in metadata:
|
466 |
+
minimum_metadata[key] = metadata[key]
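Because every value is stringified and written into the safetensors header, the ss_* fields can be read back later; nested structures such as ss_datasets were packed as JSON and need a second decode. A small sketch (the file name is hypothetical):

    import json
    from safetensors import safe_open

    with safe_open("my_lora.safetensors", framework="pt") as f:
        meta = f.metadata() or {}

    print(meta.get("ss_network_dim"), meta.get("ss_optimizer"))
    if "ss_datasets" in meta:
        datasets_meta = json.loads(meta["ss_datasets"])  # list of per-dataset dicts
        print(len(datasets_meta), "dataset(s) recorded")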
|
467 |
+
|
468 |
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
469 |
global_step = 0
|
470 |
|
|
|
477 |
loss_list = []
|
478 |
loss_total = 0.0
|
479 |
for epoch in range(num_train_epochs):
|
480 |
+
if is_main_process:
|
481 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
482 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
483 |
|
484 |
metadata["ss_epoch"] = str(epoch+1)
|
485 |
|
|
|
516 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
517 |
|
518 |
# Predict the noise residual
|
519 |
+
with accelerator.autocast():
|
520 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
521 |
|
522 |
if args.v_parameterization:
|
|
|
534 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
535 |
|
536 |
accelerator.backward(loss)
|
537 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
538 |
params_to_clip = network.get_trainable_params()
|
539 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
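Gradient clipping now only happens on synchronized steps and can be switched off entirely by setting max_grad_norm to 0. Roughly the same guard in plain PyTorch terms (a sketch using the script's own args and network objects, ignoring accelerate's gradient unscaling):

    if args.max_grad_norm != 0.0:
        torch.nn.utils.clip_grad_norm_(network.get_trainable_params(), args.max_grad_norm)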
|
540 |
|
541 |
optimizer.step()
|
542 |
lr_scheduler.step()
|
|
|
547 |
progress_bar.update(1)
|
548 |
global_step += 1
|
549 |
|
550 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
551 |
+
|
552 |
current_loss = loss.detach().item()
|
553 |
if epoch == 0:
|
554 |
loss_list.append(current_loss)
|
|
|
579 |
def save_func():
|
580 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
581 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
582 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
583 |
print(f"saving checkpoint: {ckpt_file}")
|
584 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
585 |
|
586 |
def remove_old_func(old_epoch_no):
|
587 |
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
|
|
590 |
print(f"removing old checkpoint: {old_ckpt_file}")
|
591 |
os.remove(old_ckpt_file)
|
592 |
|
593 |
+
if is_main_process:
|
594 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
595 |
+
if saving and args.save_state:
|
596 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
597 |
+
|
598 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
599 |
|
600 |
# end of epoch
|
601 |
|
602 |
metadata["ss_epoch"] = str(num_train_epochs)
|
603 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
604 |
|
|
|
605 |
if is_main_process:
|
606 |
network = unwrap_model(network)
|
607 |
|
|
|
620 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
621 |
|
622 |
print(f"save trained model to {ckpt_file}")
|
623 |
+
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
624 |
print("model saved.")
|
625 |
|
626 |
|
|
|
630 |
train_util.add_sd_models_arguments(parser)
|
631 |
train_util.add_dataset_arguments(parser, True, True, True)
|
632 |
train_util.add_training_arguments(parser, True)
|
633 |
+
train_util.add_optimizer_arguments(parser)
|
634 |
+
config_util.add_config_arguments(parser)
|
635 |
|
636 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
637 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
639 |
|
640 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
641 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
642 |
|
643 |
parser.add_argument("--network_weights", type=str, default=None,
|
644 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
train_textual_inversion.py
CHANGED
@@ -11,7 +11,11 @@ import diffusers
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
|
16 |
imagenet_templates_small = [
|
17 |
"a photo of a {}",
|
@@ -79,7 +83,6 @@ def train(args):
|
|
79 |
train_util.prepare_dataset_args(args, True)
|
80 |
|
81 |
cache_latents = args.cache_latents
|
82 |
-
use_dreambooth_method = args.in_json is None
|
83 |
|
84 |
if args.seed is not None:
|
85 |
set_seed(args.seed)
|
@@ -139,21 +142,35 @@ def train(args):
|
|
139 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
140 |
|
141 |
# データセットを準備する
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
159 |
if use_template:
|
@@ -163,20 +180,30 @@ def train(args):
|
|
163 |
captions = []
|
164 |
for tmpl in templates:
|
165 |
captions.append(tmpl.format(replace_to))
|
166 |
-
|
167 |
-
elif args.num_vectors_per_token > 1:
|
168 |
-
replace_to = " ".join(token_strings)
|
169 |
-
train_dataset.add_replacement(args.token_string, replace_to)
|
170 |
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
if args.debug_dataset:
|
174 |
-
train_util.debug_dataset(train_dataset, show_input_ids=True)
|
175 |
return
|
176 |
-
if len(train_dataset) == 0:
|
177 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
178 |
return
|
179 |
|
|
|
|
|
|
|
180 |
# モデルに xformers とか memory efficient attention を組み込む
|
181 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
182 |
|
@@ -186,7 +213,7 @@ def train(args):
|
|
186 |
vae.requires_grad_(False)
|
187 |
vae.eval()
|
188 |
with torch.no_grad():
|
189 |
-
|
190 |
vae.to("cpu")
|
191 |
if torch.cuda.is_available():
|
192 |
torch.cuda.empty_cache()
|
@@ -198,35 +225,14 @@ def train(args):
|
|
198 |
|
199 |
# 学習に必要なクラスを準備する
|
200 |
print("prepare optimizer, data loader etc.")
|
201 |
-
|
202 |
-
# 8-bit Adamを使う
|
203 |
-
if args.use_8bit_adam:
|
204 |
-
try:
|
205 |
-
import bitsandbytes as bnb
|
206 |
-
except ImportError:
|
207 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
208 |
-
print("use 8-bit Adam optimizer")
|
209 |
-
optimizer_class = bnb.optim.AdamW8bit
|
210 |
-
elif args.use_lion_optimizer:
|
211 |
-
try:
|
212 |
-
import lion_pytorch
|
213 |
-
except ImportError:
|
214 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
215 |
-
print("use Lion optimizer")
|
216 |
-
optimizer_class = lion_pytorch.Lion
|
217 |
-
else:
|
218 |
-
optimizer_class = torch.optim.AdamW
|
219 |
-
|
220 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
221 |
-
|
222 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
223 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
224 |
|
225 |
# dataloaderを準備する
|
226 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
227 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
228 |
train_dataloader = torch.utils.data.DataLoader(
|
229 |
-
|
230 |
|
231 |
# 学習ステップ数を計算する
|
232 |
if args.max_train_epochs is not None:
|
@@ -234,8 +240,9 @@ def train(args):
|
|
234 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
235 |
|
236 |
# lr schedulerを用意する
|
237 |
-
lr_scheduler = diffusers.optimization.get_scheduler(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
238 |
-
|
|
|
239 |
|
240 |
# acceleratorがなんかよろしくやってくれるらしい
|
241 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
@@ -283,8 +290,8 @@ def train(args):
|
|
283 |
# 学習する
|
284 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
285 |
print("running training / 学習開始")
|
286 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
287 |
-
print(f" num reg images / 正則化画像の数: {
|
288 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
289 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
290 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -303,12 +310,11 @@ def train(args):
|
|
303 |
|
304 |
for epoch in range(num_train_epochs):
|
305 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
306 |
-
|
307 |
|
308 |
text_encoder.train()
|
309 |
|
310 |
loss_total = 0
|
311 |
-
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
312 |
for step, batch in enumerate(train_dataloader):
|
313 |
with accelerator.accumulate(text_encoder):
|
314 |
with torch.no_grad():
|
@@ -357,9 +363,9 @@ def train(args):
|
|
357 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
358 |
|
359 |
accelerator.backward(loss)
|
360 |
-
if accelerator.sync_gradients:
|
361 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
362 |
-
accelerator.clip_grad_norm_(params_to_clip, 1.0)
|
363 |
|
364 |
optimizer.step()
|
365 |
lr_scheduler.step()
|
@@ -374,9 +380,14 @@ def train(args):
|
|
374 |
progress_bar.update(1)
|
375 |
global_step += 1
|
376 |
|
|
|
|
|
|
|
377 |
current_loss = loss.detach().item()
|
378 |
if args.logging_dir is not None:
|
379 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
380 |
accelerator.log(logs, step=global_step)
|
381 |
|
382 |
loss_total += current_loss
|
@@ -394,8 +405,6 @@ def train(args):
|
|
394 |
accelerator.wait_for_everyone()
|
395 |
|
396 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
397 |
-
# d = updated_embs - bef_epo_embs
|
398 |
-
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
399 |
|
400 |
if args.save_every_n_epochs is not None:
|
401 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
@@ -417,6 +426,9 @@ def train(args):
|
|
417 |
if saving and args.save_state:
|
418 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
419 |
|
|
|
|
|
|
|
420 |
# end of epoch
|
421 |
|
422 |
is_main_process = accelerator.is_main_process
|
@@ -491,6 +503,8 @@ if __name__ == '__main__':
|
|
491 |
train_util.add_sd_models_arguments(parser)
|
492 |
train_util.add_dataset_arguments(parser, True, True, False)
|
493 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
494 |
|
495 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
496 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
+
import library.config_util as config_util
|
15 |
+
from library.config_util import (
|
16 |
+
ConfigSanitizer,
|
17 |
+
BlueprintGenerator,
|
18 |
+
)
|
19 |
|
20 |
imagenet_templates_small = [
|
21 |
"a photo of a {}",
|
|
|
83 |
train_util.prepare_dataset_args(args, True)
|
84 |
|
85 |
cache_latents = args.cache_latents
|
|
|
86 |
|
87 |
if args.seed is not None:
|
88 |
set_seed(args.seed)
|
|
|
142 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
143 |
|
144 |
# データセットを準備する
|
145 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
146 |
+
if args.dataset_config is not None:
|
147 |
+
print(f"Load dataset config from {args.dataset_config}")
|
148 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
149 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
150 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
151 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
152 |
else:
|
153 |
+
use_dreambooth_method = args.in_json is None
|
154 |
+
if use_dreambooth_method:
|
155 |
+
print("Use DreamBooth method.")
|
156 |
+
user_config = {
|
157 |
+
"datasets": [{
|
158 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
159 |
+
}]
|
160 |
+
}
|
161 |
+
else:
|
162 |
+
print("Train with captions.")
|
163 |
+
user_config = {
|
164 |
+
"datasets": [{
|
165 |
+
"subsets": [{
|
166 |
+
"image_dir": args.train_data_dir,
|
167 |
+
"metadata_file": args.in_json,
|
168 |
+
}]
|
169 |
+
}]
|
170 |
+
}
|
171 |
+
|
172 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
173 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
174 |
|
175 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
176 |
if use_template:
|
|
|
180 |
captions = []
|
181 |
for tmpl in templates:
|
182 |
captions.append(tmpl.format(replace_to))
|
183 |
+
train_dataset_group.add_replacement("", captions)
|
|
|
|
|
|
|
184 |
|
185 |
+
if args.num_vectors_per_token > 1:
|
186 |
+
prompt_replacement = (args.token_string, replace_to)
|
187 |
+
else:
|
188 |
+
prompt_replacement = None
|
189 |
+
else:
|
190 |
+
if args.num_vectors_per_token > 1:
|
191 |
+
replace_to = " ".join(token_strings)
|
192 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
193 |
+
prompt_replacement = (args.token_string, replace_to)
|
194 |
+
else:
|
195 |
+
prompt_replacement = None
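When num_vectors_per_token is greater than one, the single token_string is assumed to be expanded into a run of numbered trainable tokens, and that full sequence is what gets substituted into captions and into sample prompts via prompt_replacement. A hedged sketch of that expansion (the exact construction of token_strings is outside this hunk):

    token_string = "sks"                  # hypothetical trigger word
    num_vectors_per_token = 3
    token_strings = [token_string] + [f"{token_string}{i}" for i in range(1, num_vectors_per_token)]
    replace_to = " ".join(token_strings)             # "sks sks1 sks2"
    prompt_replacement = (token_string, replace_to)  # applied to sample prompts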
|
196 |
|
197 |
if args.debug_dataset:
|
198 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
199 |
return
|
200 |
+
if len(train_dataset_group) == 0:
|
201 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
202 |
return
|
203 |
|
204 |
+
if cache_latents:
|
205 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
206 |
+
|
207 |
# モデルに xformers とか memory efficient attention を組み込む
|
208 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
209 |
|
|
|
213 |
vae.requires_grad_(False)
|
214 |
vae.eval()
|
215 |
with torch.no_grad():
|
216 |
+
train_dataset_group.cache_latents(vae)
|
217 |
vae.to("cpu")
|
218 |
if torch.cuda.is_available():
|
219 |
torch.cuda.empty_cache()
|
|
|
225 |
|
226 |
# 学習に必要なクラスを準備する
|
227 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
228 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
229 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
|
|
230 |
|
231 |
# dataloaderを準備する
|
232 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
233 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
234 |
train_dataloader = torch.utils.data.DataLoader(
|
235 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
236 |
|
237 |
# 学習ステップ数を計算する
|
238 |
if args.max_train_epochs is not None:
|
|
|
240 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
241 |
|
242 |
# lr schedulerを用意する
|
243 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
244 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
245 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
246 |
|
247 |
# acceleratorがなんかよろしくやってくれるらしい
|
248 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
290 |
# 学習する
|
291 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
292 |
print("running training / 学習開始")
|
293 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
294 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
295 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
296 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
297 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
310 |
|
311 |
for epoch in range(num_train_epochs):
|
312 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
313 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
314 |
|
315 |
text_encoder.train()
|
316 |
|
317 |
loss_total = 0
|
|
|
318 |
for step, batch in enumerate(train_dataloader):
|
319 |
with accelerator.accumulate(text_encoder):
|
320 |
with torch.no_grad():
|
|
|
363 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
364 |
|
365 |
accelerator.backward(loss)
|
366 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
367 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
368 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
369 |
|
370 |
optimizer.step()
|
371 |
lr_scheduler.step()
|
|
|
380 |
progress_bar.update(1)
|
381 |
global_step += 1
|
382 |
|
383 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
384 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
385 |
+
|
386 |
current_loss = loss.detach().item()
|
387 |
if args.logging_dir is not None:
|
388 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
389 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
390 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
391 |
accelerator.log(logs, step=global_step)
|
392 |
|
393 |
loss_total += current_loss
|
|
|
405 |
accelerator.wait_for_everyone()
|
406 |
|
407 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
|
|
|
408 |
|
409 |
if args.save_every_n_epochs is not None:
|
410 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
|
|
426 |
if saving and args.save_state:
|
427 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
428 |
|
429 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
430 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
431 |
+
|
432 |
# end of epoch
|
433 |
|
434 |
is_main_process = accelerator.is_main_process
|
|
|
503 |
train_util.add_sd_models_arguments(parser)
|
504 |
train_util.add_dataset_arguments(parser, True, True, False)
|
505 |
train_util.add_training_arguments(parser, True)
|
506 |
+
train_util.add_optimizer_arguments(parser)
|
507 |
+
config_util.add_config_arguments(parser)
|
508 |
|
509 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
510 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|