abc
commited on
Commit
·
350076d
1
Parent(s):
b71766b
Upload 35 files
Browse files- append_module.py +329 -32
- fine_tune.py +50 -45
- finetune/blip/blip.py +240 -0
- finetune/blip/med.py +955 -0
- finetune/blip/med_config.json +22 -0
- finetune/blip/vit.py +305 -0
- finetune/clean_captions_and_tags.py +184 -0
- finetune/hypernetwork_nai.py +96 -0
- finetune/make_captions.py +162 -0
- finetune/make_captions_by_git.py +145 -0
- finetune/merge_captions_to_metadata.py +67 -0
- finetune/merge_dd_tags_to_metadata.py +62 -0
- finetune/prepare_buckets_latents.py +261 -0
- finetune/tag_images_by_wd14_tagger.py +200 -0
- gen_img_diffusers.py +213 -48
- library/config_util.py +527 -0
- library/train_util.py +823 -230
- networks/lora.py +5 -0
- requirements.txt +2 -1
- train_db.py +47 -45
- train_network.py +212 -156
- train_network_opt.py +293 -355
- train_textual_inversion.py +68 -59
append_module.py
CHANGED
@@ -2,7 +2,19 @@ import argparse
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from accelerate import Accelerator
|
7 |
from torch.autograd.function import Function
|
8 |
import glob
|
@@ -28,6 +40,7 @@ import safetensors.torch
|
|
28 |
|
29 |
import library.model_util as model_util
|
30 |
import library.train_util as train_util
|
|
|
31 |
|
32 |
#============================================================================================================
|
33 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
@@ -115,6 +128,124 @@ def make_bucket_resolutions_fix(max_reso, min_reso, min_size=256, max_size=1024,
|
|
115 |
return area_size_resos_list, area_size_list
|
116 |
|
117 |
#============================================================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
#train_util 内より
|
119 |
#============================================================================================================
|
120 |
class BucketManager_append(train_util.BucketManager):
|
@@ -179,7 +310,7 @@ class BucketManager_append(train_util.BucketManager):
|
|
179 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
180 |
_min_error = 1000.
|
181 |
_min_id = bucket_size_id
|
182 |
-
for now_size_id in
|
183 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
184 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
185 |
ar_error = np.abs(ar_errors).min()
|
@@ -253,13 +384,13 @@ class BucketManager_append(train_util.BucketManager):
|
|
253 |
return reso, resized_size, ar_error
|
254 |
|
255 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
256 |
-
def __init__(self,
|
257 |
print("use append DreamBoothDataset")
|
258 |
self.min_resolution = min_resolution
|
259 |
self.area_step = area_step
|
260 |
-
super().__init__(
|
261 |
-
|
262 |
-
|
263 |
def make_buckets(self):
|
264 |
'''
|
265 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
@@ -353,11 +484,10 @@ class DreamBoothDataset(train_util.DreamBoothDataset):
|
|
353 |
self._length = len(self.buckets_indices)
|
354 |
|
355 |
class FineTuningDataset(train_util.FineTuningDataset):
|
356 |
-
def __init__(self,
|
357 |
train_util.glob_images = glob_images
|
358 |
-
super().__init__(
|
359 |
-
|
360 |
-
random_crop, dataset_repeats, debug_dataset)
|
361 |
|
362 |
def glob_images(directory, base="*", npz_flag=True):
|
363 |
img_paths = []
|
@@ -373,13 +503,26 @@ def glob_images(directory, base="*", npz_flag=True):
|
|
373 |
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
374 |
return img_paths
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
#============================================================================================================
|
377 |
#networks.lora
|
378 |
#============================================================================================================
|
379 |
from networks.lora import LoRANetwork
|
380 |
def replace_prepare_optimizer_params(networks):
|
381 |
-
def prepare_optimizer_params(self, text_encoder_lr, unet_lr,
|
382 |
-
|
383 |
def enumerate_params(loras, lora_name=None):
|
384 |
params = []
|
385 |
for lora in loras:
|
@@ -393,6 +536,7 @@ def replace_prepare_optimizer_params(networks):
|
|
393 |
self.requires_grad_(True)
|
394 |
all_params = []
|
395 |
ret_scheduler_lr = []
|
|
|
396 |
|
397 |
if loranames is not None:
|
398 |
textencoder_names = [None]
|
@@ -405,22 +549,60 @@ def replace_prepare_optimizer_params(networks):
|
|
405 |
if self.text_encoder_loras:
|
406 |
for textencoder_name in textencoder_names:
|
407 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
|
|
408 |
if text_encoder_lr is not None:
|
409 |
param_data['lr'] = text_encoder_lr
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
all_params.append(param_data)
|
413 |
|
414 |
if self.unet_loras:
|
415 |
for unet_name in unet_names:
|
416 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
|
|
417 |
if unet_lr is not None:
|
418 |
param_data['lr'] = unet_lr
|
419 |
-
|
420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
all_params.append(param_data)
|
422 |
|
423 |
-
return all_params, ret_scheduler_lr
|
424 |
|
425 |
LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
|
426 |
|
@@ -429,14 +611,98 @@ def replace_prepare_optimizer_params(networks):
|
|
429 |
#============================================================================================================
|
430 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
431 |
# for train_network_opt.py
|
432 |
-
parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
433 |
-
parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
|
|
|
|
434 |
parser.add_argument("--split_lora_networks", action="store_true")
|
435 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
|
|
|
|
436 |
parser.add_argument("--min_resolution", type=str, default=None)
|
437 |
parser.add_argument("--area_step", type=int, default=1)
|
438 |
parser.add_argument("--config", type=str, default=None)
|
439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
def create_split_names(split_flag, split_level):
|
441 |
split_names = None
|
442 |
if split_flag:
|
@@ -446,14 +712,23 @@ def create_split_names(split_flag, split_level):
|
|
446 |
if split_level==1:
|
447 |
unet_names.append(f"lora_unet_down_blocks_")
|
448 |
unet_names.append(f"lora_unet_up_blocks_")
|
449 |
-
elif split_level==2 or split_level==0:
|
450 |
-
if split_level
|
451 |
text_encoder_names = []
|
452 |
for i in range(12):
|
453 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
for i in range(3):
|
455 |
-
|
456 |
-
|
|
|
|
|
457 |
split_names["text_encoder"] = text_encoder_names
|
458 |
split_names["unet"] = unet_names
|
459 |
return split_names
|
@@ -465,7 +740,7 @@ def get_config(parser):
|
|
465 |
import datetime
|
466 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
467 |
args.config = os.path.splitext(args.config)[0]
|
468 |
-
config_path = f"
|
469 |
if os.path.exists(config_path):
|
470 |
print(f"{config_path} から設定を読���込み中...")
|
471 |
margs, rest = parser.parse_known_args()
|
@@ -486,19 +761,41 @@ def get_config(parser):
|
|
486 |
args_type_dic[key] = act.type
|
487 |
#データタイプの確認とargsにkeyの内容を代入していく
|
488 |
for key, v in configs.items():
|
489 |
-
if
|
490 |
-
if
|
491 |
-
|
492 |
-
|
493 |
-
v
|
494 |
-
|
495 |
-
|
496 |
if not type(v) == args_type_dic[key]:
|
497 |
v = args_type_dic[key](v)
|
498 |
-
|
499 |
#最後にデフォから指定が変わってるものを変更する
|
500 |
for key, v in change_def_dic.items():
|
501 |
args_dic[key] = v
|
502 |
else:
|
503 |
print(f"{config_path} が見つかりませんでした")
|
504 |
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import json
|
3 |
import shutil
|
4 |
import time
|
5 |
+
from typing import (
|
6 |
+
Dict,
|
7 |
+
List,
|
8 |
+
NamedTuple,
|
9 |
+
Optional,
|
10 |
+
Sequence,
|
11 |
+
Tuple,
|
12 |
+
Union,
|
13 |
+
)
|
14 |
+
from dataclasses import (
|
15 |
+
asdict,
|
16 |
+
dataclass,
|
17 |
+
)
|
18 |
from accelerate import Accelerator
|
19 |
from torch.autograd.function import Function
|
20 |
import glob
|
|
|
40 |
|
41 |
import library.model_util as model_util
|
42 |
import library.train_util as train_util
|
43 |
+
import library.config_util as config_util
|
44 |
|
45 |
#============================================================================================================
|
46 |
#AdafactorScheduleに暫定的にinitial_lrを層別に適用できるようにしたもの
|
|
|
128 |
return area_size_resos_list, area_size_list
|
129 |
|
130 |
#============================================================================================================
|
131 |
+
#config_util 内より
|
132 |
+
#============================================================================================================
|
133 |
+
@dataclass
|
134 |
+
class DreamBoothDatasetParams(config_util.DreamBoothDatasetParams):
|
135 |
+
min_resolution: Optional[Tuple[int, int]] = None
|
136 |
+
area_step : int = 2
|
137 |
+
|
138 |
+
class ConfigSanitizer(config_util.ConfigSanitizer):
|
139 |
+
#@config_util.curry
|
140 |
+
@staticmethod
|
141 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
142 |
+
config_util.Schema(config_util.ExactSequence([klass, klass]))(value)
|
143 |
+
return tuple(value)
|
144 |
+
|
145 |
+
#@config_util.curry
|
146 |
+
@staticmethod
|
147 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
148 |
+
config_util.Schema(config_util.Any(klass, config_util.ExactSequence([klass, klass])))(value)
|
149 |
+
try:
|
150 |
+
config_util.Schema(klass)(value)
|
151 |
+
return (value, value)
|
152 |
+
except:
|
153 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
154 |
+
# datasets schema
|
155 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
156 |
+
"batch_size": int,
|
157 |
+
"bucket_no_upscale": bool,
|
158 |
+
"bucket_reso_steps": int,
|
159 |
+
"enable_bucket": bool,
|
160 |
+
"max_bucket_reso": int,
|
161 |
+
"min_bucket_reso": int,
|
162 |
+
"resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
163 |
+
"min_resolution": config_util.functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
164 |
+
"area_step": int,
|
165 |
+
}
|
166 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
|
167 |
+
super().__init__(support_dreambooth, support_finetuning, support_dropout)
|
168 |
+
def _check(self):
|
169 |
+
print(self.db_dataset_schema)
|
170 |
+
|
171 |
+
class BlueprintGenerator(config_util.BlueprintGenerator):
|
172 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
173 |
+
config_util.DreamBoothDatasetParams = DreamBoothDatasetParams
|
174 |
+
super().__init__(sanitizer)
|
175 |
+
|
176 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: config_util.DatasetGroupBlueprint):
|
177 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
|
178 |
+
|
179 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
180 |
+
if dataset_blueprint.is_dreambooth:
|
181 |
+
subset_klass = train_util.DreamBoothSubset
|
182 |
+
dataset_klass = DreamBoothDataset
|
183 |
+
else:
|
184 |
+
subset_klass = train_util.FineTuningSubset
|
185 |
+
dataset_klass = FineTuningDataset
|
186 |
+
|
187 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
188 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
189 |
+
datasets.append(dataset)
|
190 |
+
|
191 |
+
# print info
|
192 |
+
info = ""
|
193 |
+
for i, dataset in enumerate(datasets):
|
194 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
195 |
+
info += config_util.dedent(f"""\
|
196 |
+
[Dataset {i}]
|
197 |
+
batch_size: {dataset.batch_size}
|
198 |
+
resolution: {(dataset.width, dataset.height)}
|
199 |
+
enable_bucket: {dataset.enable_bucket}
|
200 |
+
""")
|
201 |
+
|
202 |
+
if dataset.enable_bucket:
|
203 |
+
info += config_util.indent(config_util.dedent(f"""\
|
204 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
205 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
206 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
207 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
208 |
+
\n"""), " ")
|
209 |
+
else:
|
210 |
+
info += "\n"
|
211 |
+
|
212 |
+
for j, subset in enumerate(dataset.subsets):
|
213 |
+
info += config_util.indent(config_util.dedent(f"""\
|
214 |
+
[Subset {j} of Dataset {i}]
|
215 |
+
image_dir: "{subset.image_dir}"
|
216 |
+
image_count: {subset.img_count}
|
217 |
+
num_repeats: {subset.num_repeats}
|
218 |
+
shuffle_caption: {subset.shuffle_caption}
|
219 |
+
keep_tokens: {subset.keep_tokens}
|
220 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
221 |
+
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
222 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
223 |
+
color_aug: {subset.color_aug}
|
224 |
+
flip_aug: {subset.flip_aug}
|
225 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
226 |
+
random_crop: {subset.random_crop}
|
227 |
+
"""), " ")
|
228 |
+
|
229 |
+
if is_dreambooth:
|
230 |
+
info += config_util.indent(config_util.dedent(f"""\
|
231 |
+
is_reg: {subset.is_reg}
|
232 |
+
class_tokens: {subset.class_tokens}
|
233 |
+
caption_extension: {subset.caption_extension}
|
234 |
+
\n"""), " ")
|
235 |
+
else:
|
236 |
+
info += config_util.indent(config_util.dedent(f"""\
|
237 |
+
metadata_file: {subset.metadata_file}
|
238 |
+
\n"""), " ")
|
239 |
+
|
240 |
+
print(info)
|
241 |
+
|
242 |
+
# make buckets first because it determines the length of dataset
|
243 |
+
for i, dataset in enumerate(datasets):
|
244 |
+
print(f"[Dataset {i}]")
|
245 |
+
dataset.make_buckets()
|
246 |
+
|
247 |
+
return train_util.DatasetGroup(datasets)
|
248 |
+
#============================================================================================================
|
249 |
#train_util 内より
|
250 |
#============================================================================================================
|
251 |
class BucketManager_append(train_util.BucketManager):
|
|
|
310 |
bucket_size_id_list.append(bucket_size_id + i + 1)
|
311 |
_min_error = 1000.
|
312 |
_min_id = bucket_size_id
|
313 |
+
for now_size_id in bucket_size_id_list:
|
314 |
self.predefined_aspect_ratios = self.predefined_aspect_ratios_list[now_size_id]
|
315 |
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
316 |
ar_error = np.abs(ar_errors).min()
|
|
|
384 |
return reso, resized_size, ar_error
|
385 |
|
386 |
class DreamBoothDataset(train_util.DreamBoothDataset):
|
387 |
+
def __init__(self, subsets: Sequence[train_util.DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset, min_resolution=None, area_step=None) -> None:
|
388 |
print("use append DreamBoothDataset")
|
389 |
self.min_resolution = min_resolution
|
390 |
self.area_step = area_step
|
391 |
+
super().__init__(subsets, batch_size, tokenizer, max_token_length,
|
392 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale,
|
393 |
+
prior_loss_weight, debug_dataset)
|
394 |
def make_buckets(self):
|
395 |
'''
|
396 |
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
|
|
484 |
self._length = len(self.buckets_indices)
|
485 |
|
486 |
class FineTuningDataset(train_util.FineTuningDataset):
|
487 |
+
def __init__(self, subsets: Sequence[train_util.FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
488 |
train_util.glob_images = glob_images
|
489 |
+
super().__init__(subsets, batch_size, tokenizer, max_token_length,
|
490 |
+
resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, debug_dataset)
|
|
|
491 |
|
492 |
def glob_images(directory, base="*", npz_flag=True):
|
493 |
img_paths = []
|
|
|
503 |
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
504 |
return img_paths
|
505 |
|
506 |
+
import transformers
|
507 |
+
from torch.optim import Optimizer
|
508 |
+
from diffusers.optimization import SchedulerType
|
509 |
+
from typing import Union
|
510 |
+
def get_scheduler_Adafactor(
|
511 |
+
name: Union[str, SchedulerType],
|
512 |
+
optimizer: Optimizer,
|
513 |
+
scheduler_arg: Dict
|
514 |
+
):
|
515 |
+
if name.startswith("adafactor"):
|
516 |
+
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
517 |
+
print(scheduler_arg)
|
518 |
+
return AdafactorSchedule_append(optimizer, **scheduler_arg)
|
519 |
#============================================================================================================
|
520 |
#networks.lora
|
521 |
#============================================================================================================
|
522 |
from networks.lora import LoRANetwork
|
523 |
def replace_prepare_optimizer_params(networks):
|
524 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, loranames=None, lr_dic=None, block_args_dic=None):
|
525 |
+
|
526 |
def enumerate_params(loras, lora_name=None):
|
527 |
params = []
|
528 |
for lora in loras:
|
|
|
536 |
self.requires_grad_(True)
|
537 |
all_params = []
|
538 |
ret_scheduler_lr = []
|
539 |
+
used_names = []
|
540 |
|
541 |
if loranames is not None:
|
542 |
textencoder_names = [None]
|
|
|
549 |
if self.text_encoder_loras:
|
550 |
for textencoder_name in textencoder_names:
|
551 |
param_data = {'params': enumerate_params(self.text_encoder_loras, lora_name=textencoder_name)}
|
552 |
+
used_names.append(textencoder_name)
|
553 |
if text_encoder_lr is not None:
|
554 |
param_data['lr'] = text_encoder_lr
|
555 |
+
if lr_dic is not None:
|
556 |
+
if textencoder_name in lr_dic:
|
557 |
+
param_data['lr'] = lr_dic[textencoder_name]
|
558 |
+
print(f"{textencoder_name} lr: {param_data['lr']}")
|
559 |
+
|
560 |
+
if block_args_dic is not None:
|
561 |
+
if "lora_te_" in block_args_dic:
|
562 |
+
for pname, value in block_args_dic["lora_te_"].items():
|
563 |
+
param_data[pname] = value
|
564 |
+
if textencoder_name in block_args_dic:
|
565 |
+
for pname, value in block_args_dic[textencoder_name].items():
|
566 |
+
param_data[pname] = value
|
567 |
+
|
568 |
+
if text_encoder_lr is not None:
|
569 |
+
ret_scheduler_lr.append(text_encoder_lr)
|
570 |
+
else:
|
571 |
+
ret_scheduler_lr.append(0.)
|
572 |
+
if lr_dic is not None:
|
573 |
+
if textencoder_name in lr_dic:
|
574 |
+
ret_scheduler_lr[-1] = lr_dic[textencoder_name]
|
575 |
all_params.append(param_data)
|
576 |
|
577 |
if self.unet_loras:
|
578 |
for unet_name in unet_names:
|
579 |
param_data = {'params': enumerate_params(self.unet_loras, lora_name=unet_name)}
|
580 |
+
used_names.append(unet_name)
|
581 |
if unet_lr is not None:
|
582 |
param_data['lr'] = unet_lr
|
583 |
+
if lr_dic is not None:
|
584 |
+
if unet_name in lr_dic:
|
585 |
+
param_data['lr'] = lr_dic[unet_name]
|
586 |
+
print(f"{unet_name} lr: {param_data['lr']}")
|
587 |
+
|
588 |
+
if block_args_dic is not None:
|
589 |
+
if "lora_unet_" in block_args_dic:
|
590 |
+
for pname, value in block_args_dic["lora_unet_"].items():
|
591 |
+
param_data[pname] = value
|
592 |
+
if unet_name in block_args_dic:
|
593 |
+
for pname, value in block_args_dic[unet_name].items():
|
594 |
+
param_data[pname] = value
|
595 |
+
|
596 |
+
if unet_lr is not None:
|
597 |
+
ret_scheduler_lr.append(unet_lr)
|
598 |
+
else:
|
599 |
+
ret_scheduler_lr.append(0.)
|
600 |
+
if lr_dic is not None:
|
601 |
+
if unet_name in lr_dic:
|
602 |
+
ret_scheduler_lr[-1] = lr_dic[unet_name]
|
603 |
all_params.append(param_data)
|
604 |
|
605 |
+
return all_params, {"initial_lr" : ret_scheduler_lr}, used_names
|
606 |
|
607 |
LoRANetwork.prepare_optimizer_params = prepare_optimizer_params
|
608 |
|
|
|
611 |
#============================================================================================================
|
612 |
def add_append_arguments(parser: argparse.ArgumentParser):
|
613 |
# for train_network_opt.py
|
614 |
+
#parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "RAdam", "AdaBound", "AdaBelief", "AggMo", "AdamP", "Adastand", "Adastand_belief", "Apollo", "Lamb", "Ranger", "RangerVA", "Lookahead_Adam", "Lookahead_DiffGrad", "Yogi", "NovoGrad", "QHAdam", "DiffGrad", "MADGRAD", "Adafactor"], help="使用するoptimizerを指定する")
|
615 |
+
#parser.add_argument("--optimizer_arg", type=str, default=None, nargs='*')
|
616 |
+
parser.add_argument("--use_lookahead", action="store_true")
|
617 |
+
parser.add_argument("--lookahead_arg", type=str, nargs="*", default=None)
|
618 |
parser.add_argument("--split_lora_networks", action="store_true")
|
619 |
parser.add_argument("--split_lora_level", type=int, default=0, help="どれくらい細分化するかの設定 0がunetのみを層別に 1がunetを大枠で分割 2がtextencoder含めて層別")
|
620 |
+
parser.add_argument("--blocks_lr_setting", type=str, default=None)
|
621 |
+
parser.add_argument("--block_optim_args", type=str, nargs="*", default=None)
|
622 |
parser.add_argument("--min_resolution", type=str, default=None)
|
623 |
parser.add_argument("--area_step", type=int, default=1)
|
624 |
parser.add_argument("--config", type=str, default=None)
|
625 |
|
626 |
+
def create_lr_blocks(lr_setting_str=None, block_optim_args=None):
|
627 |
+
ex_block_weight_dic = {
|
628 |
+
"BASE": "te",
|
629 |
+
"IN01": "down_0_at_0", "IN02": "down_0_at_1",
|
630 |
+
"IN04": "down_1_at_0", "IN05": "down_1_at_1",
|
631 |
+
"IN07": "down_2_at_0", "IN08": "down_2_at_1",
|
632 |
+
"MID": "mid",
|
633 |
+
"OUT03": "up_1_at_0", "OUT04": "up_1_at_1", "OUT05": "up_1_at_2",
|
634 |
+
"OUT06": "up_2_at_0", "OUT07": "up_2_at_1", "OUT08": "up_2_at_2",
|
635 |
+
"OUT09": "up_3_at_0", "OUT10": "up_3_at_1", "OUT11": "up_3_at_2",
|
636 |
+
}
|
637 |
+
|
638 |
+
blocks_name_dic = { "te": "lora_te_",
|
639 |
+
"unet": "lora_unet_",
|
640 |
+
"mid": "lora_unet_mid_block",
|
641 |
+
"down": "lora_unet_down_blocks_",
|
642 |
+
"up": "lora_unet_up_blocks_"}
|
643 |
+
for i in range(12):
|
644 |
+
blocks_name_dic[f"te_{i}"] = f"lora_te_text_model_encoder_layers_{i}_"
|
645 |
+
for i in range(3):
|
646 |
+
blocks_name_dic[f"down_{i}"] = f"lora_unet_down_blocks_{i}"
|
647 |
+
blocks_name_dic[f"up_{i+1}"] = f"lora_unet_up_blocks_{i+1}"
|
648 |
+
for i in range(3):
|
649 |
+
for j in range(2):
|
650 |
+
blocks_name_dic[f"down_{i}_at_{j}"] = f"lora_unet_down_blocks_{i}_attentions_{j}_"
|
651 |
+
for j in range(3):
|
652 |
+
blocks_name_dic[f"up_{i+1}_at_{j}"] = f"lora_unet_up_blocks_{i+1}_attentions_{j}_"
|
653 |
+
|
654 |
+
lr_dic = {}
|
655 |
+
if lr_setting_str==None or lr_setting_str=="":
|
656 |
+
pass
|
657 |
+
else:
|
658 |
+
lr_settings = lr_setting_str.replace(" ", "").split(",")
|
659 |
+
for lr_setting in lr_settings:
|
660 |
+
key, value = lr_setting.split("=")
|
661 |
+
if key in ex_block_weight_dic:
|
662 |
+
key = ex_block_weight_dic[key]
|
663 |
+
if key in blocks_name_dic:
|
664 |
+
new_key = blocks_name_dic[key]
|
665 |
+
lr_dic[new_key] = float(value)
|
666 |
+
if len(lr_dic)==0:
|
667 |
+
lr_dic = None
|
668 |
+
|
669 |
+
args_dic = {}
|
670 |
+
if (block_optim_args is None):
|
671 |
+
block_optim_args = []
|
672 |
+
if (len(block_optim_args)>0):
|
673 |
+
for my_arg in block_optim_args:
|
674 |
+
my_arg = my_arg.replace(" ", "")
|
675 |
+
splits = my_arg.split(":")
|
676 |
+
b_name = splits[0]
|
677 |
+
if b_name in ex_block_weight_dic:
|
678 |
+
b_name = ex_block_weight_dic[b_name]
|
679 |
+
new_b_name = blocks_name_dic[b_name]
|
680 |
+
key, _value = splits[1].split("=")
|
681 |
+
value_type = float
|
682 |
+
if len(splits)==3:
|
683 |
+
if _value=="str":
|
684 |
+
value_type = str
|
685 |
+
elif _value=="int":
|
686 |
+
value_type = int
|
687 |
+
_value = splits[2]
|
688 |
+
if _value=="true" or _value=="false":
|
689 |
+
value_type = bool
|
690 |
+
if "," in _value:
|
691 |
+
_value = _value.split(",")
|
692 |
+
for i in range(len(_value)):
|
693 |
+
_value[i] = value_type(_value[i])
|
694 |
+
value=tuple(_value)
|
695 |
+
else:
|
696 |
+
value = value_type(_value)
|
697 |
+
|
698 |
+
if not new_b_name in args_dic:
|
699 |
+
args_dic[new_b_name] = {}
|
700 |
+
args_dic[new_b_name][key] = value
|
701 |
+
|
702 |
+
if len(args_dic)==0:
|
703 |
+
args_dic = None
|
704 |
+
return lr_dic, args_dic
|
705 |
+
|
706 |
def create_split_names(split_flag, split_level):
|
707 |
split_names = None
|
708 |
if split_flag:
|
|
|
712 |
if split_level==1:
|
713 |
unet_names.append(f"lora_unet_down_blocks_")
|
714 |
unet_names.append(f"lora_unet_up_blocks_")
|
715 |
+
elif split_level==2 or split_level==0 or split_level==4:
|
716 |
+
if split_level>=2:
|
717 |
text_encoder_names = []
|
718 |
for i in range(12):
|
719 |
text_encoder_names.append(f"lora_te_text_model_encoder_layers_{i}_")
|
720 |
+
|
721 |
+
if split_level<=2:
|
722 |
+
for i in range(3):
|
723 |
+
unet_names.append(f"lora_unet_down_blocks_{i}")
|
724 |
+
unet_names.append(f"lora_unet_up_blocks_{i+1}")
|
725 |
+
|
726 |
+
if split_level>=3:
|
727 |
for i in range(3):
|
728 |
+
for j in range(2):
|
729 |
+
unet_names.append(f"lora_unet_down_blocks_{i}_attentions_{j}_")
|
730 |
+
for j in range(3):
|
731 |
+
unet_names.append(f"lora_unet_up_blocks_{i+1}_attentions_{j}_")
|
732 |
split_names["text_encoder"] = text_encoder_names
|
733 |
split_names["unet"] = unet_names
|
734 |
return split_names
|
|
|
740 |
import datetime
|
741 |
if os.path.splitext(args.config)[-1] == ".yaml":
|
742 |
args.config = os.path.splitext(args.config)[0]
|
743 |
+
config_path = f"{args.config}.yaml"
|
744 |
if os.path.exists(config_path):
|
745 |
print(f"{config_path} から設定を読���込み中...")
|
746 |
margs, rest = parser.parse_known_args()
|
|
|
761 |
args_type_dic[key] = act.type
|
762 |
#データタイプの確認とargsにkeyの内容を代入していく
|
763 |
for key, v in configs.items():
|
764 |
+
if v is not None:
|
765 |
+
if key in args_dic:
|
766 |
+
if args_dic[key] is not None:
|
767 |
+
new_type = type(args_dic[key])
|
768 |
+
if (not type(v) == new_type) and (not new_type==list):
|
769 |
+
v = new_type(v)
|
770 |
+
else:
|
771 |
if not type(v) == args_type_dic[key]:
|
772 |
v = args_type_dic[key](v)
|
773 |
+
args_dic[key] = v
|
774 |
#最後にデフォから指定が変わってるものを変更する
|
775 |
for key, v in change_def_dic.items():
|
776 |
args_dic[key] = v
|
777 |
else:
|
778 |
print(f"{config_path} が見つかりませんでした")
|
779 |
return args
|
780 |
+
|
781 |
+
'''
|
782 |
+
class GradientReversalFunction(torch.autograd.Function):
|
783 |
+
@staticmethod
|
784 |
+
def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
785 |
+
ctx.save_for_backward(scale)
|
786 |
+
return input_forward
|
787 |
+
@staticmethod
|
788 |
+
def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
789 |
+
scale, = ctx.saved_tensors
|
790 |
+
return scale * -grad_backward, None
|
791 |
+
|
792 |
+
class GradientReversal(torch.nn.Module):
|
793 |
+
def __init__(self, scale: float):
|
794 |
+
super(GradientReversal, self).__init__()
|
795 |
+
self.scale = torch.tensor(scale)
|
796 |
+
def forward(self, x: torch.Tensor, flag: bool = False) -> torch.Tensor:
|
797 |
+
if flag:
|
798 |
+
return x
|
799 |
+
else:
|
800 |
+
return GradientReversalFunction.apply(x, self.scale)
|
801 |
+
'''
|
fine_tune.py
CHANGED
@@ -13,7 +13,11 @@ import diffusers
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def collate_fn(examples):
|
19 |
return examples[0]
|
@@ -30,25 +34,36 @@ def train(args):
|
|
30 |
|
31 |
tokenizer = train_util.load_tokenizer(args)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
if args.debug_dataset:
|
46 |
-
train_util.debug_dataset(
|
47 |
return
|
48 |
-
if len(
|
49 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
50 |
return
|
51 |
|
|
|
|
|
|
|
52 |
# acceleratorを準備する
|
53 |
print("prepare accelerator")
|
54 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -109,7 +124,7 @@ def train(args):
|
|
109 |
vae.requires_grad_(False)
|
110 |
vae.eval()
|
111 |
with torch.no_grad():
|
112 |
-
|
113 |
vae.to("cpu")
|
114 |
if torch.cuda.is_available():
|
115 |
torch.cuda.empty_cache()
|
@@ -149,33 +164,13 @@ def train(args):
|
|
149 |
|
150 |
# 学習に必要なクラスを準備する
|
151 |
print("prepare optimizer, data loader etc.")
|
152 |
-
|
153 |
-
# 8-bit Adamを使う
|
154 |
-
if args.use_8bit_adam:
|
155 |
-
try:
|
156 |
-
import bitsandbytes as bnb
|
157 |
-
except ImportError:
|
158 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
159 |
-
print("use 8-bit Adam optimizer")
|
160 |
-
optimizer_class = bnb.optim.AdamW8bit
|
161 |
-
elif args.use_lion_optimizer:
|
162 |
-
try:
|
163 |
-
import lion_pytorch
|
164 |
-
except ImportError:
|
165 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
166 |
-
print("use Lion optimizer")
|
167 |
-
optimizer_class = lion_pytorch.Lion
|
168 |
-
else:
|
169 |
-
optimizer_class = torch.optim.AdamW
|
170 |
-
|
171 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
172 |
-
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
173 |
|
174 |
# dataloaderを準備する
|
175 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
176 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
177 |
train_dataloader = torch.utils.data.DataLoader(
|
178 |
-
|
179 |
|
180 |
# 学習ステップ数を計算する
|
181 |
if args.max_train_epochs is not None:
|
@@ -183,8 +178,9 @@ def train(args):
|
|
183 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
184 |
|
185 |
# lr schedulerを用意する
|
186 |
-
lr_scheduler =
|
187 |
-
|
|
|
188 |
|
189 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
190 |
if args.full_fp16:
|
@@ -218,7 +214,7 @@ def train(args):
|
|
218 |
# 学習する
|
219 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
220 |
print("running training / 学習開始")
|
221 |
-
print(f" num examples / サンプル数: {
|
222 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
223 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
224 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -237,7 +233,7 @@ def train(args):
|
|
237 |
|
238 |
for epoch in range(num_train_epochs):
|
239 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
240 |
-
|
241 |
|
242 |
for m in training_models:
|
243 |
m.train()
|
@@ -286,11 +282,11 @@ def train(args):
|
|
286 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
287 |
|
288 |
accelerator.backward(loss)
|
289 |
-
if accelerator.sync_gradients:
|
290 |
params_to_clip = []
|
291 |
for m in training_models:
|
292 |
params_to_clip.extend(m.parameters())
|
293 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
294 |
|
295 |
optimizer.step()
|
296 |
lr_scheduler.step()
|
@@ -301,11 +297,16 @@ def train(args):
|
|
301 |
progress_bar.update(1)
|
302 |
global_step += 1
|
303 |
|
|
|
|
|
304 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
305 |
if args.logging_dir is not None:
|
306 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
|
|
309 |
loss_total += current_loss
|
310 |
avr_loss = loss_total / (step+1)
|
311 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
@@ -315,7 +316,7 @@ def train(args):
|
|
315 |
break
|
316 |
|
317 |
if args.logging_dir is not None:
|
318 |
-
logs = {"
|
319 |
accelerator.log(logs, step=epoch+1)
|
320 |
|
321 |
accelerator.wait_for_everyone()
|
@@ -325,6 +326,8 @@ def train(args):
|
|
325 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
326 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
327 |
|
|
|
|
|
328 |
is_main_process = accelerator.is_main_process
|
329 |
if is_main_process:
|
330 |
unet = unwrap_model(unet)
|
@@ -351,6 +354,8 @@ if __name__ == '__main__':
|
|
351 |
train_util.add_dataset_arguments(parser, False, True, True)
|
352 |
train_util.add_training_arguments(parser, False)
|
353 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
354 |
|
355 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
356 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
|
|
13 |
from diffusers import DDPMScheduler
|
14 |
|
15 |
import library.train_util as train_util
|
16 |
+
import library.config_util as config_util
|
17 |
+
from library.config_util import (
|
18 |
+
ConfigSanitizer,
|
19 |
+
BlueprintGenerator,
|
20 |
+
)
|
21 |
|
22 |
def collate_fn(examples):
|
23 |
return examples[0]
|
|
|
34 |
|
35 |
tokenizer = train_util.load_tokenizer(args)
|
36 |
|
37 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
38 |
+
if args.dataset_config is not None:
|
39 |
+
print(f"Load dataset config from {args.dataset_config}")
|
40 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
41 |
+
ignored = ["train_data_dir", "in_json"]
|
42 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
43 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
44 |
+
else:
|
45 |
+
user_config = {
|
46 |
+
"datasets": [{
|
47 |
+
"subsets": [{
|
48 |
+
"image_dir": args.train_data_dir,
|
49 |
+
"metadata_file": args.in_json,
|
50 |
+
}]
|
51 |
+
}]
|
52 |
+
}
|
53 |
+
|
54 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
if args.debug_dataset:
|
58 |
+
train_util.debug_dataset(train_dataset_group)
|
59 |
return
|
60 |
+
if len(train_dataset_group) == 0:
|
61 |
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
62 |
return
|
63 |
|
64 |
+
if cache_latents:
|
65 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
+
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
124 |
vae.requires_grad_(False)
|
125 |
vae.eval()
|
126 |
with torch.no_grad():
|
127 |
+
train_dataset_group.cache_latents(vae)
|
128 |
vae.to("cpu")
|
129 |
if torch.cuda.is_available():
|
130 |
torch.cuda.empty_cache()
|
|
|
164 |
|
165 |
# 学習に必要なクラスを準備する
|
166 |
print("prepare optimizer, data loader etc.")
|
167 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
# dataloaderを準備する
|
170 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
171 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
172 |
train_dataloader = torch.utils.data.DataLoader(
|
173 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
174 |
|
175 |
# 学習ステップ数を計算する
|
176 |
if args.max_train_epochs is not None:
|
|
|
178 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
179 |
|
180 |
# lr schedulerを用意する
|
181 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
182 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
183 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
184 |
|
185 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
186 |
if args.full_fp16:
|
|
|
214 |
# 学習する
|
215 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
216 |
print("running training / 学習開始")
|
217 |
+
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
218 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
219 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
220 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
233 |
|
234 |
for epoch in range(num_train_epochs):
|
235 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
236 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
237 |
|
238 |
for m in training_models:
|
239 |
m.train()
|
|
|
282 |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
283 |
|
284 |
accelerator.backward(loss)
|
285 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
286 |
params_to_clip = []
|
287 |
for m in training_models:
|
288 |
params_to_clip.extend(m.parameters())
|
289 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
300 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
301 |
+
|
302 |
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
303 |
if args.logging_dir is not None:
|
304 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
305 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
306 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
307 |
accelerator.log(logs, step=global_step)
|
308 |
|
309 |
+
# TODO moving averageにする
|
310 |
loss_total += current_loss
|
311 |
avr_loss = loss_total / (step+1)
|
312 |
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
316 |
break
|
317 |
|
318 |
if args.logging_dir is not None:
|
319 |
+
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
320 |
accelerator.log(logs, step=epoch+1)
|
321 |
|
322 |
accelerator.wait_for_everyone()
|
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
329 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
330 |
+
|
331 |
is_main_process = accelerator.is_main_process
|
332 |
if is_main_process:
|
333 |
unet = unwrap_model(unet)
|
|
|
354 |
train_util.add_dataset_arguments(parser, False, True, True)
|
355 |
train_util.add_training_arguments(parser, False)
|
356 |
train_util.add_sd_saving_arguments(parser)
|
357 |
+
train_util.add_optimizer_arguments(parser)
|
358 |
+
config_util.add_config_arguments(parser)
|
359 |
|
360 |
parser.add_argument("--diffusers_xformers", action='store_true',
|
361 |
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
finetune/blip/blip.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import warnings
|
9 |
+
warnings.filterwarnings("ignore")
|
10 |
+
|
11 |
+
# from models.vit import VisionTransformer, interpolate_pos_embed
|
12 |
+
# from models.med import BertConfig, BertModel, BertLMHeadModel
|
13 |
+
from blip.vit import VisionTransformer, interpolate_pos_embed
|
14 |
+
from blip.med import BertConfig, BertModel, BertLMHeadModel
|
15 |
+
from transformers import BertTokenizer
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
import os
|
22 |
+
from urllib.parse import urlparse
|
23 |
+
from timm.models.hub import download_cached_file
|
24 |
+
|
25 |
+
class BLIP_Base(nn.Module):
|
26 |
+
def __init__(self,
|
27 |
+
med_config = 'configs/med_config.json',
|
28 |
+
image_size = 224,
|
29 |
+
vit = 'base',
|
30 |
+
vit_grad_ckpt = False,
|
31 |
+
vit_ckpt_layer = 0,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
36 |
+
image_size (int): input image size
|
37 |
+
vit (str): model size of vision transformer
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
42 |
+
self.tokenizer = init_tokenizer()
|
43 |
+
med_config = BertConfig.from_json_file(med_config)
|
44 |
+
med_config.encoder_width = vision_width
|
45 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
46 |
+
|
47 |
+
|
48 |
+
def forward(self, image, caption, mode):
|
49 |
+
|
50 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
51 |
+
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
52 |
+
|
53 |
+
if mode=='image':
|
54 |
+
# return image features
|
55 |
+
image_embeds = self.visual_encoder(image)
|
56 |
+
return image_embeds
|
57 |
+
|
58 |
+
elif mode=='text':
|
59 |
+
# return text features
|
60 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
61 |
+
return_dict = True, mode = 'text')
|
62 |
+
return text_output.last_hidden_state
|
63 |
+
|
64 |
+
elif mode=='multimodal':
|
65 |
+
# return multimodel features
|
66 |
+
image_embeds = self.visual_encoder(image)
|
67 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
68 |
+
|
69 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
70 |
+
output = self.text_encoder(text.input_ids,
|
71 |
+
attention_mask = text.attention_mask,
|
72 |
+
encoder_hidden_states = image_embeds,
|
73 |
+
encoder_attention_mask = image_atts,
|
74 |
+
return_dict = True,
|
75 |
+
)
|
76 |
+
return output.last_hidden_state
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
class BLIP_Decoder(nn.Module):
|
81 |
+
def __init__(self,
|
82 |
+
med_config = 'configs/med_config.json',
|
83 |
+
image_size = 384,
|
84 |
+
vit = 'base',
|
85 |
+
vit_grad_ckpt = False,
|
86 |
+
vit_ckpt_layer = 0,
|
87 |
+
prompt = 'a picture of ',
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Args:
|
91 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
92 |
+
image_size (int): input image size
|
93 |
+
vit (str): model size of vision transformer
|
94 |
+
"""
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
98 |
+
self.tokenizer = init_tokenizer()
|
99 |
+
med_config = BertConfig.from_json_file(med_config)
|
100 |
+
med_config.encoder_width = vision_width
|
101 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
102 |
+
|
103 |
+
self.prompt = prompt
|
104 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self, image, caption):
|
108 |
+
|
109 |
+
image_embeds = self.visual_encoder(image)
|
110 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
111 |
+
|
112 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
113 |
+
|
114 |
+
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
115 |
+
|
116 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
117 |
+
decoder_targets[:,:self.prompt_length] = -100
|
118 |
+
|
119 |
+
decoder_output = self.text_decoder(text.input_ids,
|
120 |
+
attention_mask = text.attention_mask,
|
121 |
+
encoder_hidden_states = image_embeds,
|
122 |
+
encoder_attention_mask = image_atts,
|
123 |
+
labels = decoder_targets,
|
124 |
+
return_dict = True,
|
125 |
+
)
|
126 |
+
loss_lm = decoder_output.loss
|
127 |
+
|
128 |
+
return loss_lm
|
129 |
+
|
130 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
131 |
+
image_embeds = self.visual_encoder(image)
|
132 |
+
|
133 |
+
if not sample:
|
134 |
+
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
135 |
+
|
136 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
137 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
138 |
+
|
139 |
+
prompt = [self.prompt] * image.size(0)
|
140 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
141 |
+
input_ids[:,0] = self.tokenizer.bos_token_id
|
142 |
+
input_ids = input_ids[:, :-1]
|
143 |
+
|
144 |
+
if sample:
|
145 |
+
#nucleus sampling
|
146 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
147 |
+
max_length=max_length,
|
148 |
+
min_length=min_length,
|
149 |
+
do_sample=True,
|
150 |
+
top_p=top_p,
|
151 |
+
num_return_sequences=1,
|
152 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
153 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
154 |
+
repetition_penalty=1.1,
|
155 |
+
**model_kwargs)
|
156 |
+
else:
|
157 |
+
#beam search
|
158 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
159 |
+
max_length=max_length,
|
160 |
+
min_length=min_length,
|
161 |
+
num_beams=num_beams,
|
162 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
163 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
164 |
+
repetition_penalty=repetition_penalty,
|
165 |
+
**model_kwargs)
|
166 |
+
|
167 |
+
captions = []
|
168 |
+
for output in outputs:
|
169 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
170 |
+
captions.append(caption[len(self.prompt):])
|
171 |
+
return captions
|
172 |
+
|
173 |
+
|
174 |
+
def blip_decoder(pretrained='',**kwargs):
|
175 |
+
model = BLIP_Decoder(**kwargs)
|
176 |
+
if pretrained:
|
177 |
+
model,msg = load_checkpoint(model,pretrained)
|
178 |
+
assert(len(msg.missing_keys)==0)
|
179 |
+
return model
|
180 |
+
|
181 |
+
def blip_feature_extractor(pretrained='',**kwargs):
|
182 |
+
model = BLIP_Base(**kwargs)
|
183 |
+
if pretrained:
|
184 |
+
model,msg = load_checkpoint(model,pretrained)
|
185 |
+
assert(len(msg.missing_keys)==0)
|
186 |
+
return model
|
187 |
+
|
188 |
+
def init_tokenizer():
|
189 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
190 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
191 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
192 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
193 |
+
return tokenizer
|
194 |
+
|
195 |
+
|
196 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
197 |
+
|
198 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
199 |
+
if vit=='base':
|
200 |
+
vision_width = 768
|
201 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
202 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
203 |
+
drop_path_rate=0 or drop_path_rate
|
204 |
+
)
|
205 |
+
elif vit=='large':
|
206 |
+
vision_width = 1024
|
207 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
208 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
209 |
+
drop_path_rate=0.1 or drop_path_rate
|
210 |
+
)
|
211 |
+
return visual_encoder, vision_width
|
212 |
+
|
213 |
+
def is_url(url_or_filename):
|
214 |
+
parsed = urlparse(url_or_filename)
|
215 |
+
return parsed.scheme in ("http", "https")
|
216 |
+
|
217 |
+
def load_checkpoint(model,url_or_filename):
|
218 |
+
if is_url(url_or_filename):
|
219 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
220 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
221 |
+
elif os.path.isfile(url_or_filename):
|
222 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
223 |
+
else:
|
224 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
225 |
+
|
226 |
+
state_dict = checkpoint['model']
|
227 |
+
|
228 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
229 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
230 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
231 |
+
model.visual_encoder_m)
|
232 |
+
for key in model.state_dict().keys():
|
233 |
+
if key in state_dict.keys():
|
234 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
235 |
+
del state_dict[key]
|
236 |
+
|
237 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
238 |
+
print('load checkpoint from %s'%url_or_filename)
|
239 |
+
return model,msg
|
240 |
+
|
finetune/blip/med.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
'''
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
59 |
+
|
60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
61 |
+
# any TensorFlow checkpoint file
|
62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
64 |
+
|
65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
68 |
+
|
69 |
+
self.config = config
|
70 |
+
|
71 |
+
def forward(
|
72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
73 |
+
):
|
74 |
+
if input_ids is not None:
|
75 |
+
input_shape = input_ids.size()
|
76 |
+
else:
|
77 |
+
input_shape = inputs_embeds.size()[:-1]
|
78 |
+
|
79 |
+
seq_length = input_shape[1]
|
80 |
+
|
81 |
+
if position_ids is None:
|
82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
83 |
+
|
84 |
+
if inputs_embeds is None:
|
85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
86 |
+
|
87 |
+
embeddings = inputs_embeds
|
88 |
+
|
89 |
+
if self.position_embedding_type == "absolute":
|
90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
91 |
+
embeddings += position_embeddings
|
92 |
+
embeddings = self.LayerNorm(embeddings)
|
93 |
+
embeddings = self.dropout(embeddings)
|
94 |
+
return embeddings
|
95 |
+
|
96 |
+
|
97 |
+
class BertSelfAttention(nn.Module):
|
98 |
+
def __init__(self, config, is_cross_attention):
|
99 |
+
super().__init__()
|
100 |
+
self.config = config
|
101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
102 |
+
raise ValueError(
|
103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
105 |
+
)
|
106 |
+
|
107 |
+
self.num_attention_heads = config.num_attention_heads
|
108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
110 |
+
|
111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
112 |
+
if is_cross_attention:
|
113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
115 |
+
else:
|
116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
118 |
+
|
119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
124 |
+
self.save_attention = False
|
125 |
+
|
126 |
+
def save_attn_gradients(self, attn_gradients):
|
127 |
+
self.attn_gradients = attn_gradients
|
128 |
+
|
129 |
+
def get_attn_gradients(self):
|
130 |
+
return self.attn_gradients
|
131 |
+
|
132 |
+
def save_attention_map(self, attention_map):
|
133 |
+
self.attention_map = attention_map
|
134 |
+
|
135 |
+
def get_attention_map(self):
|
136 |
+
return self.attention_map
|
137 |
+
|
138 |
+
def transpose_for_scores(self, x):
|
139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
140 |
+
x = x.view(*new_x_shape)
|
141 |
+
return x.permute(0, 2, 1, 3)
|
142 |
+
|
143 |
+
def forward(
|
144 |
+
self,
|
145 |
+
hidden_states,
|
146 |
+
attention_mask=None,
|
147 |
+
head_mask=None,
|
148 |
+
encoder_hidden_states=None,
|
149 |
+
encoder_attention_mask=None,
|
150 |
+
past_key_value=None,
|
151 |
+
output_attentions=False,
|
152 |
+
):
|
153 |
+
mixed_query_layer = self.query(hidden_states)
|
154 |
+
|
155 |
+
# If this is instantiated as a cross-attention module, the keys
|
156 |
+
# and values come from an encoder; the attention mask needs to be
|
157 |
+
# such that the encoder's padding tokens are not attended to.
|
158 |
+
is_cross_attention = encoder_hidden_states is not None
|
159 |
+
|
160 |
+
if is_cross_attention:
|
161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
163 |
+
attention_mask = encoder_attention_mask
|
164 |
+
elif past_key_value is not None:
|
165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
169 |
+
else:
|
170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
172 |
+
|
173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
174 |
+
|
175 |
+
past_key_value = (key_layer, value_layer)
|
176 |
+
|
177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
181 |
+
seq_length = hidden_states.size()[1]
|
182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
184 |
+
distance = position_ids_l - position_ids_r
|
185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
187 |
+
|
188 |
+
if self.position_embedding_type == "relative_key":
|
189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
190 |
+
attention_scores = attention_scores + relative_position_scores
|
191 |
+
elif self.position_embedding_type == "relative_key_query":
|
192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
195 |
+
|
196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
197 |
+
if attention_mask is not None:
|
198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
199 |
+
attention_scores = attention_scores + attention_mask
|
200 |
+
|
201 |
+
# Normalize the attention scores to probabilities.
|
202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
203 |
+
|
204 |
+
if is_cross_attention and self.save_attention:
|
205 |
+
self.save_attention_map(attention_probs)
|
206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
207 |
+
|
208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
211 |
+
|
212 |
+
# Mask heads if we want to
|
213 |
+
if head_mask is not None:
|
214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
215 |
+
|
216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
217 |
+
|
218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
221 |
+
|
222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
223 |
+
|
224 |
+
outputs = outputs + (past_key_value,)
|
225 |
+
return outputs
|
226 |
+
|
227 |
+
|
228 |
+
class BertSelfOutput(nn.Module):
|
229 |
+
def __init__(self, config):
|
230 |
+
super().__init__()
|
231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
234 |
+
|
235 |
+
def forward(self, hidden_states, input_tensor):
|
236 |
+
hidden_states = self.dense(hidden_states)
|
237 |
+
hidden_states = self.dropout(hidden_states)
|
238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
239 |
+
return hidden_states
|
240 |
+
|
241 |
+
|
242 |
+
class BertAttention(nn.Module):
|
243 |
+
def __init__(self, config, is_cross_attention=False):
|
244 |
+
super().__init__()
|
245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
246 |
+
self.output = BertSelfOutput(config)
|
247 |
+
self.pruned_heads = set()
|
248 |
+
|
249 |
+
def prune_heads(self, heads):
|
250 |
+
if len(heads) == 0:
|
251 |
+
return
|
252 |
+
heads, index = find_pruneable_heads_and_indices(
|
253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
254 |
+
)
|
255 |
+
|
256 |
+
# Prune linear layers
|
257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
261 |
+
|
262 |
+
# Update hyper params and store pruned heads
|
263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states,
|
270 |
+
attention_mask=None,
|
271 |
+
head_mask=None,
|
272 |
+
encoder_hidden_states=None,
|
273 |
+
encoder_attention_mask=None,
|
274 |
+
past_key_value=None,
|
275 |
+
output_attentions=False,
|
276 |
+
):
|
277 |
+
self_outputs = self.self(
|
278 |
+
hidden_states,
|
279 |
+
attention_mask,
|
280 |
+
head_mask,
|
281 |
+
encoder_hidden_states,
|
282 |
+
encoder_attention_mask,
|
283 |
+
past_key_value,
|
284 |
+
output_attentions,
|
285 |
+
)
|
286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
288 |
+
return outputs
|
289 |
+
|
290 |
+
|
291 |
+
class BertIntermediate(nn.Module):
|
292 |
+
def __init__(self, config):
|
293 |
+
super().__init__()
|
294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
295 |
+
if isinstance(config.hidden_act, str):
|
296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
297 |
+
else:
|
298 |
+
self.intermediate_act_fn = config.hidden_act
|
299 |
+
|
300 |
+
def forward(self, hidden_states):
|
301 |
+
hidden_states = self.dense(hidden_states)
|
302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
303 |
+
return hidden_states
|
304 |
+
|
305 |
+
|
306 |
+
class BertOutput(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super().__init__()
|
309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, input_tensor):
|
314 |
+
hidden_states = self.dense(hidden_states)
|
315 |
+
hidden_states = self.dropout(hidden_states)
|
316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
317 |
+
return hidden_states
|
318 |
+
|
319 |
+
|
320 |
+
class BertLayer(nn.Module):
|
321 |
+
def __init__(self, config, layer_num):
|
322 |
+
super().__init__()
|
323 |
+
self.config = config
|
324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
325 |
+
self.seq_len_dim = 1
|
326 |
+
self.attention = BertAttention(config)
|
327 |
+
self.layer_num = layer_num
|
328 |
+
if self.config.add_cross_attention:
|
329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
330 |
+
self.intermediate = BertIntermediate(config)
|
331 |
+
self.output = BertOutput(config)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=None,
|
337 |
+
head_mask=None,
|
338 |
+
encoder_hidden_states=None,
|
339 |
+
encoder_attention_mask=None,
|
340 |
+
past_key_value=None,
|
341 |
+
output_attentions=False,
|
342 |
+
mode=None,
|
343 |
+
):
|
344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
346 |
+
self_attention_outputs = self.attention(
|
347 |
+
hidden_states,
|
348 |
+
attention_mask,
|
349 |
+
head_mask,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
past_key_value=self_attn_past_key_value,
|
352 |
+
)
|
353 |
+
attention_output = self_attention_outputs[0]
|
354 |
+
|
355 |
+
outputs = self_attention_outputs[1:-1]
|
356 |
+
present_key_value = self_attention_outputs[-1]
|
357 |
+
|
358 |
+
if mode=='multimodal':
|
359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
360 |
+
|
361 |
+
cross_attention_outputs = self.crossattention(
|
362 |
+
attention_output,
|
363 |
+
attention_mask,
|
364 |
+
head_mask,
|
365 |
+
encoder_hidden_states,
|
366 |
+
encoder_attention_mask,
|
367 |
+
output_attentions=output_attentions,
|
368 |
+
)
|
369 |
+
attention_output = cross_attention_outputs[0]
|
370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
371 |
+
layer_output = apply_chunking_to_forward(
|
372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
373 |
+
)
|
374 |
+
outputs = (layer_output,) + outputs
|
375 |
+
|
376 |
+
outputs = outputs + (present_key_value,)
|
377 |
+
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
def feed_forward_chunk(self, attention_output):
|
381 |
+
intermediate_output = self.intermediate(attention_output)
|
382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
383 |
+
return layer_output
|
384 |
+
|
385 |
+
|
386 |
+
class BertEncoder(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
super().__init__()
|
389 |
+
self.config = config
|
390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def forward(
|
394 |
+
self,
|
395 |
+
hidden_states,
|
396 |
+
attention_mask=None,
|
397 |
+
head_mask=None,
|
398 |
+
encoder_hidden_states=None,
|
399 |
+
encoder_attention_mask=None,
|
400 |
+
past_key_values=None,
|
401 |
+
use_cache=None,
|
402 |
+
output_attentions=False,
|
403 |
+
output_hidden_states=False,
|
404 |
+
return_dict=True,
|
405 |
+
mode='multimodal',
|
406 |
+
):
|
407 |
+
all_hidden_states = () if output_hidden_states else None
|
408 |
+
all_self_attentions = () if output_attentions else None
|
409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
410 |
+
|
411 |
+
next_decoder_cache = () if use_cache else None
|
412 |
+
|
413 |
+
for i in range(self.config.num_hidden_layers):
|
414 |
+
layer_module = self.layer[i]
|
415 |
+
if output_hidden_states:
|
416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
417 |
+
|
418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
420 |
+
|
421 |
+
if self.gradient_checkpointing and self.training:
|
422 |
+
|
423 |
+
if use_cache:
|
424 |
+
logger.warn(
|
425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
426 |
+
)
|
427 |
+
use_cache = False
|
428 |
+
|
429 |
+
def create_custom_forward(module):
|
430 |
+
def custom_forward(*inputs):
|
431 |
+
return module(*inputs, past_key_value, output_attentions)
|
432 |
+
|
433 |
+
return custom_forward
|
434 |
+
|
435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
436 |
+
create_custom_forward(layer_module),
|
437 |
+
hidden_states,
|
438 |
+
attention_mask,
|
439 |
+
layer_head_mask,
|
440 |
+
encoder_hidden_states,
|
441 |
+
encoder_attention_mask,
|
442 |
+
mode=mode,
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
layer_outputs = layer_module(
|
446 |
+
hidden_states,
|
447 |
+
attention_mask,
|
448 |
+
layer_head_mask,
|
449 |
+
encoder_hidden_states,
|
450 |
+
encoder_attention_mask,
|
451 |
+
past_key_value,
|
452 |
+
output_attentions,
|
453 |
+
mode=mode,
|
454 |
+
)
|
455 |
+
|
456 |
+
hidden_states = layer_outputs[0]
|
457 |
+
if use_cache:
|
458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
459 |
+
if output_attentions:
|
460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
461 |
+
|
462 |
+
if output_hidden_states:
|
463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
464 |
+
|
465 |
+
if not return_dict:
|
466 |
+
return tuple(
|
467 |
+
v
|
468 |
+
for v in [
|
469 |
+
hidden_states,
|
470 |
+
next_decoder_cache,
|
471 |
+
all_hidden_states,
|
472 |
+
all_self_attentions,
|
473 |
+
all_cross_attentions,
|
474 |
+
]
|
475 |
+
if v is not None
|
476 |
+
)
|
477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
478 |
+
last_hidden_state=hidden_states,
|
479 |
+
past_key_values=next_decoder_cache,
|
480 |
+
hidden_states=all_hidden_states,
|
481 |
+
attentions=all_self_attentions,
|
482 |
+
cross_attentions=all_cross_attentions,
|
483 |
+
)
|
484 |
+
|
485 |
+
|
486 |
+
class BertPooler(nn.Module):
|
487 |
+
def __init__(self, config):
|
488 |
+
super().__init__()
|
489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
490 |
+
self.activation = nn.Tanh()
|
491 |
+
|
492 |
+
def forward(self, hidden_states):
|
493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
494 |
+
# to the first token.
|
495 |
+
first_token_tensor = hidden_states[:, 0]
|
496 |
+
pooled_output = self.dense(first_token_tensor)
|
497 |
+
pooled_output = self.activation(pooled_output)
|
498 |
+
return pooled_output
|
499 |
+
|
500 |
+
|
501 |
+
class BertPredictionHeadTransform(nn.Module):
|
502 |
+
def __init__(self, config):
|
503 |
+
super().__init__()
|
504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
505 |
+
if isinstance(config.hidden_act, str):
|
506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
507 |
+
else:
|
508 |
+
self.transform_act_fn = config.hidden_act
|
509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
510 |
+
|
511 |
+
def forward(self, hidden_states):
|
512 |
+
hidden_states = self.dense(hidden_states)
|
513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
515 |
+
return hidden_states
|
516 |
+
|
517 |
+
|
518 |
+
class BertLMPredictionHead(nn.Module):
|
519 |
+
def __init__(self, config):
|
520 |
+
super().__init__()
|
521 |
+
self.transform = BertPredictionHeadTransform(config)
|
522 |
+
|
523 |
+
# The output weights are the same as the input embeddings, but there is
|
524 |
+
# an output-only bias for each token.
|
525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
526 |
+
|
527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
528 |
+
|
529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
530 |
+
self.decoder.bias = self.bias
|
531 |
+
|
532 |
+
def forward(self, hidden_states):
|
533 |
+
hidden_states = self.transform(hidden_states)
|
534 |
+
hidden_states = self.decoder(hidden_states)
|
535 |
+
return hidden_states
|
536 |
+
|
537 |
+
|
538 |
+
class BertOnlyMLMHead(nn.Module):
|
539 |
+
def __init__(self, config):
|
540 |
+
super().__init__()
|
541 |
+
self.predictions = BertLMPredictionHead(config)
|
542 |
+
|
543 |
+
def forward(self, sequence_output):
|
544 |
+
prediction_scores = self.predictions(sequence_output)
|
545 |
+
return prediction_scores
|
546 |
+
|
547 |
+
|
548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
549 |
+
"""
|
550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
551 |
+
models.
|
552 |
+
"""
|
553 |
+
|
554 |
+
config_class = BertConfig
|
555 |
+
base_model_prefix = "bert"
|
556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
557 |
+
|
558 |
+
def _init_weights(self, module):
|
559 |
+
""" Initialize the weights """
|
560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
564 |
+
elif isinstance(module, nn.LayerNorm):
|
565 |
+
module.bias.data.zero_()
|
566 |
+
module.weight.data.fill_(1.0)
|
567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
568 |
+
module.bias.data.zero_()
|
569 |
+
|
570 |
+
|
571 |
+
class BertModel(BertPreTrainedModel):
|
572 |
+
"""
|
573 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
574 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
575 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
576 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
577 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
578 |
+
input to the forward pass.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, config, add_pooling_layer=True):
|
582 |
+
super().__init__(config)
|
583 |
+
self.config = config
|
584 |
+
|
585 |
+
self.embeddings = BertEmbeddings(config)
|
586 |
+
|
587 |
+
self.encoder = BertEncoder(config)
|
588 |
+
|
589 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
590 |
+
|
591 |
+
self.init_weights()
|
592 |
+
|
593 |
+
|
594 |
+
def get_input_embeddings(self):
|
595 |
+
return self.embeddings.word_embeddings
|
596 |
+
|
597 |
+
def set_input_embeddings(self, value):
|
598 |
+
self.embeddings.word_embeddings = value
|
599 |
+
|
600 |
+
def _prune_heads(self, heads_to_prune):
|
601 |
+
"""
|
602 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
603 |
+
class PreTrainedModel
|
604 |
+
"""
|
605 |
+
for layer, heads in heads_to_prune.items():
|
606 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
607 |
+
|
608 |
+
|
609 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
610 |
+
"""
|
611 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
612 |
+
|
613 |
+
Arguments:
|
614 |
+
attention_mask (:obj:`torch.Tensor`):
|
615 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
616 |
+
input_shape (:obj:`Tuple[int]`):
|
617 |
+
The shape of the input to the model.
|
618 |
+
device: (:obj:`torch.device`):
|
619 |
+
The device of the input to the model.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
623 |
+
"""
|
624 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
625 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
626 |
+
if attention_mask.dim() == 3:
|
627 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
628 |
+
elif attention_mask.dim() == 2:
|
629 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
630 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
631 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
632 |
+
if is_decoder:
|
633 |
+
batch_size, seq_length = input_shape
|
634 |
+
|
635 |
+
seq_ids = torch.arange(seq_length, device=device)
|
636 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
637 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
638 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
639 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
640 |
+
|
641 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
642 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
643 |
+
causal_mask = torch.cat(
|
644 |
+
[
|
645 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
646 |
+
causal_mask,
|
647 |
+
],
|
648 |
+
axis=-1,
|
649 |
+
)
|
650 |
+
|
651 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
652 |
+
else:
|
653 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
654 |
+
else:
|
655 |
+
raise ValueError(
|
656 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
657 |
+
input_shape, attention_mask.shape
|
658 |
+
)
|
659 |
+
)
|
660 |
+
|
661 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
662 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
663 |
+
# positions we want to attend and -10000.0 for masked positions.
|
664 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
665 |
+
# effectively the same as removing these entirely.
|
666 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
667 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
668 |
+
return extended_attention_mask
|
669 |
+
|
670 |
+
def forward(
|
671 |
+
self,
|
672 |
+
input_ids=None,
|
673 |
+
attention_mask=None,
|
674 |
+
position_ids=None,
|
675 |
+
head_mask=None,
|
676 |
+
inputs_embeds=None,
|
677 |
+
encoder_embeds=None,
|
678 |
+
encoder_hidden_states=None,
|
679 |
+
encoder_attention_mask=None,
|
680 |
+
past_key_values=None,
|
681 |
+
use_cache=None,
|
682 |
+
output_attentions=None,
|
683 |
+
output_hidden_states=None,
|
684 |
+
return_dict=None,
|
685 |
+
is_decoder=False,
|
686 |
+
mode='multimodal',
|
687 |
+
):
|
688 |
+
r"""
|
689 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
690 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
691 |
+
the model is configured as a decoder.
|
692 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
693 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
694 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
695 |
+
- 1 for tokens that are **not masked**,
|
696 |
+
- 0 for tokens that are **masked**.
|
697 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
698 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
699 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
700 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
701 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
702 |
+
use_cache (:obj:`bool`, `optional`):
|
703 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
704 |
+
decoding (see :obj:`past_key_values`).
|
705 |
+
"""
|
706 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
707 |
+
output_hidden_states = (
|
708 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
709 |
+
)
|
710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
711 |
+
|
712 |
+
if is_decoder:
|
713 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
714 |
+
else:
|
715 |
+
use_cache = False
|
716 |
+
|
717 |
+
if input_ids is not None and inputs_embeds is not None:
|
718 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
719 |
+
elif input_ids is not None:
|
720 |
+
input_shape = input_ids.size()
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = input_ids.device
|
723 |
+
elif inputs_embeds is not None:
|
724 |
+
input_shape = inputs_embeds.size()[:-1]
|
725 |
+
batch_size, seq_length = input_shape
|
726 |
+
device = inputs_embeds.device
|
727 |
+
elif encoder_embeds is not None:
|
728 |
+
input_shape = encoder_embeds.size()[:-1]
|
729 |
+
batch_size, seq_length = input_shape
|
730 |
+
device = encoder_embeds.device
|
731 |
+
else:
|
732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
733 |
+
|
734 |
+
# past_key_values_length
|
735 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
736 |
+
|
737 |
+
if attention_mask is None:
|
738 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
739 |
+
|
740 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
741 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
742 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
743 |
+
device, is_decoder)
|
744 |
+
|
745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
746 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
747 |
+
if encoder_hidden_states is not None:
|
748 |
+
if type(encoder_hidden_states) == list:
|
749 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
750 |
+
else:
|
751 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
752 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
753 |
+
|
754 |
+
if type(encoder_attention_mask) == list:
|
755 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
756 |
+
elif encoder_attention_mask is None:
|
757 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
758 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
759 |
+
else:
|
760 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
761 |
+
else:
|
762 |
+
encoder_extended_attention_mask = None
|
763 |
+
|
764 |
+
# Prepare head mask if needed
|
765 |
+
# 1.0 in head_mask indicate we keep the head
|
766 |
+
# attention_probs has shape bsz x n_heads x N x N
|
767 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
768 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
769 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
770 |
+
|
771 |
+
if encoder_embeds is None:
|
772 |
+
embedding_output = self.embeddings(
|
773 |
+
input_ids=input_ids,
|
774 |
+
position_ids=position_ids,
|
775 |
+
inputs_embeds=inputs_embeds,
|
776 |
+
past_key_values_length=past_key_values_length,
|
777 |
+
)
|
778 |
+
else:
|
779 |
+
embedding_output = encoder_embeds
|
780 |
+
|
781 |
+
encoder_outputs = self.encoder(
|
782 |
+
embedding_output,
|
783 |
+
attention_mask=extended_attention_mask,
|
784 |
+
head_mask=head_mask,
|
785 |
+
encoder_hidden_states=encoder_hidden_states,
|
786 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
787 |
+
past_key_values=past_key_values,
|
788 |
+
use_cache=use_cache,
|
789 |
+
output_attentions=output_attentions,
|
790 |
+
output_hidden_states=output_hidden_states,
|
791 |
+
return_dict=return_dict,
|
792 |
+
mode=mode,
|
793 |
+
)
|
794 |
+
sequence_output = encoder_outputs[0]
|
795 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
796 |
+
|
797 |
+
if not return_dict:
|
798 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
799 |
+
|
800 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
801 |
+
last_hidden_state=sequence_output,
|
802 |
+
pooler_output=pooled_output,
|
803 |
+
past_key_values=encoder_outputs.past_key_values,
|
804 |
+
hidden_states=encoder_outputs.hidden_states,
|
805 |
+
attentions=encoder_outputs.attentions,
|
806 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
807 |
+
)
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
812 |
+
|
813 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
814 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
815 |
+
|
816 |
+
def __init__(self, config):
|
817 |
+
super().__init__(config)
|
818 |
+
|
819 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
820 |
+
self.cls = BertOnlyMLMHead(config)
|
821 |
+
|
822 |
+
self.init_weights()
|
823 |
+
|
824 |
+
def get_output_embeddings(self):
|
825 |
+
return self.cls.predictions.decoder
|
826 |
+
|
827 |
+
def set_output_embeddings(self, new_embeddings):
|
828 |
+
self.cls.predictions.decoder = new_embeddings
|
829 |
+
|
830 |
+
def forward(
|
831 |
+
self,
|
832 |
+
input_ids=None,
|
833 |
+
attention_mask=None,
|
834 |
+
position_ids=None,
|
835 |
+
head_mask=None,
|
836 |
+
inputs_embeds=None,
|
837 |
+
encoder_hidden_states=None,
|
838 |
+
encoder_attention_mask=None,
|
839 |
+
labels=None,
|
840 |
+
past_key_values=None,
|
841 |
+
use_cache=None,
|
842 |
+
output_attentions=None,
|
843 |
+
output_hidden_states=None,
|
844 |
+
return_dict=None,
|
845 |
+
return_logits=False,
|
846 |
+
is_decoder=True,
|
847 |
+
reduction='mean',
|
848 |
+
mode='multimodal',
|
849 |
+
):
|
850 |
+
r"""
|
851 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
852 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
853 |
+
the model is configured as a decoder.
|
854 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
855 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
856 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
857 |
+
- 1 for tokens that are **not masked**,
|
858 |
+
- 0 for tokens that are **masked**.
|
859 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
860 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
861 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
862 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
863 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
864 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
865 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
866 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
867 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
868 |
+
use_cache (:obj:`bool`, `optional`):
|
869 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
870 |
+
decoding (see :obj:`past_key_values`).
|
871 |
+
Returns:
|
872 |
+
Example::
|
873 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
874 |
+
>>> import torch
|
875 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
876 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
877 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
878 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
879 |
+
>>> outputs = model(**inputs)
|
880 |
+
>>> prediction_logits = outputs.logits
|
881 |
+
"""
|
882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
883 |
+
if labels is not None:
|
884 |
+
use_cache = False
|
885 |
+
|
886 |
+
outputs = self.bert(
|
887 |
+
input_ids,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
position_ids=position_ids,
|
890 |
+
head_mask=head_mask,
|
891 |
+
inputs_embeds=inputs_embeds,
|
892 |
+
encoder_hidden_states=encoder_hidden_states,
|
893 |
+
encoder_attention_mask=encoder_attention_mask,
|
894 |
+
past_key_values=past_key_values,
|
895 |
+
use_cache=use_cache,
|
896 |
+
output_attentions=output_attentions,
|
897 |
+
output_hidden_states=output_hidden_states,
|
898 |
+
return_dict=return_dict,
|
899 |
+
is_decoder=is_decoder,
|
900 |
+
mode=mode,
|
901 |
+
)
|
902 |
+
|
903 |
+
sequence_output = outputs[0]
|
904 |
+
prediction_scores = self.cls(sequence_output)
|
905 |
+
|
906 |
+
if return_logits:
|
907 |
+
return prediction_scores[:, :-1, :].contiguous()
|
908 |
+
|
909 |
+
lm_loss = None
|
910 |
+
if labels is not None:
|
911 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
912 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
913 |
+
labels = labels[:, 1:].contiguous()
|
914 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
915 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
916 |
+
if reduction=='none':
|
917 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
output = (prediction_scores,) + outputs[2:]
|
921 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
922 |
+
|
923 |
+
return CausalLMOutputWithCrossAttentions(
|
924 |
+
loss=lm_loss,
|
925 |
+
logits=prediction_scores,
|
926 |
+
past_key_values=outputs.past_key_values,
|
927 |
+
hidden_states=outputs.hidden_states,
|
928 |
+
attentions=outputs.attentions,
|
929 |
+
cross_attentions=outputs.cross_attentions,
|
930 |
+
)
|
931 |
+
|
932 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
933 |
+
input_shape = input_ids.shape
|
934 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
935 |
+
if attention_mask is None:
|
936 |
+
attention_mask = input_ids.new_ones(input_shape)
|
937 |
+
|
938 |
+
# cut decoder_input_ids if past is used
|
939 |
+
if past is not None:
|
940 |
+
input_ids = input_ids[:, -1:]
|
941 |
+
|
942 |
+
return {
|
943 |
+
"input_ids": input_ids,
|
944 |
+
"attention_mask": attention_mask,
|
945 |
+
"past_key_values": past,
|
946 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
947 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
948 |
+
"is_decoder": True,
|
949 |
+
}
|
950 |
+
|
951 |
+
def _reorder_cache(self, past, beam_idx):
|
952 |
+
reordered_past = ()
|
953 |
+
for layer_past in past:
|
954 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
955 |
+
return reordered_past
|
finetune/blip/med_config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30524,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|
22 |
+
|
finetune/blip/vit.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on timm code base
|
8 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
17 |
+
from timm.models.registry import register_model
|
18 |
+
from timm.models.layers import trunc_normal_, DropPath
|
19 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
20 |
+
|
21 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
22 |
+
|
23 |
+
class Mlp(nn.Module):
|
24 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
25 |
+
"""
|
26 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class Attention(nn.Module):
|
45 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
50 |
+
self.scale = qk_scale or head_dim ** -0.5
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
self.attn_gradients = None
|
56 |
+
self.attention_map = None
|
57 |
+
|
58 |
+
def save_attn_gradients(self, attn_gradients):
|
59 |
+
self.attn_gradients = attn_gradients
|
60 |
+
|
61 |
+
def get_attn_gradients(self):
|
62 |
+
return self.attn_gradients
|
63 |
+
|
64 |
+
def save_attention_map(self, attention_map):
|
65 |
+
self.attention_map = attention_map
|
66 |
+
|
67 |
+
def get_attention_map(self):
|
68 |
+
return self.attention_map
|
69 |
+
|
70 |
+
def forward(self, x, register_hook=False):
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
73 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
74 |
+
|
75 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
76 |
+
attn = attn.softmax(dim=-1)
|
77 |
+
attn = self.attn_drop(attn)
|
78 |
+
|
79 |
+
if register_hook:
|
80 |
+
self.save_attention_map(attn)
|
81 |
+
attn.register_hook(self.save_attn_gradients)
|
82 |
+
|
83 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
84 |
+
x = self.proj(x)
|
85 |
+
x = self.proj_drop(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class Block(nn.Module):
|
90 |
+
|
91 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
92 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
93 |
+
super().__init__()
|
94 |
+
self.norm1 = norm_layer(dim)
|
95 |
+
self.attn = Attention(
|
96 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
97 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
98 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
99 |
+
self.norm2 = norm_layer(dim)
|
100 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
101 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
102 |
+
|
103 |
+
if use_grad_checkpointing:
|
104 |
+
self.attn = checkpoint_wrapper(self.attn)
|
105 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
106 |
+
|
107 |
+
def forward(self, x, register_hook=False):
|
108 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
109 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class VisionTransformer(nn.Module):
|
114 |
+
""" Vision Transformer
|
115 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
116 |
+
https://arxiv.org/abs/2010.11929
|
117 |
+
"""
|
118 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
119 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
120 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
121 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
img_size (int, tuple): input image size
|
125 |
+
patch_size (int, tuple): patch size
|
126 |
+
in_chans (int): number of input channels
|
127 |
+
num_classes (int): number of classes for classification head
|
128 |
+
embed_dim (int): embedding dimension
|
129 |
+
depth (int): depth of transformer
|
130 |
+
num_heads (int): number of attention heads
|
131 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
132 |
+
qkv_bias (bool): enable bias for qkv if True
|
133 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
134 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
135 |
+
drop_rate (float): dropout rate
|
136 |
+
attn_drop_rate (float): attention dropout rate
|
137 |
+
drop_path_rate (float): stochastic depth rate
|
138 |
+
norm_layer: (nn.Module): normalization layer
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
142 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
143 |
+
|
144 |
+
self.patch_embed = PatchEmbed(
|
145 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
146 |
+
|
147 |
+
num_patches = self.patch_embed.num_patches
|
148 |
+
|
149 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
150 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
151 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
152 |
+
|
153 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
154 |
+
self.blocks = nn.ModuleList([
|
155 |
+
Block(
|
156 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
157 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
158 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
159 |
+
)
|
160 |
+
for i in range(depth)])
|
161 |
+
self.norm = norm_layer(embed_dim)
|
162 |
+
|
163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
164 |
+
trunc_normal_(self.cls_token, std=.02)
|
165 |
+
self.apply(self._init_weights)
|
166 |
+
|
167 |
+
def _init_weights(self, m):
|
168 |
+
if isinstance(m, nn.Linear):
|
169 |
+
trunc_normal_(m.weight, std=.02)
|
170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
elif isinstance(m, nn.LayerNorm):
|
173 |
+
nn.init.constant_(m.bias, 0)
|
174 |
+
nn.init.constant_(m.weight, 1.0)
|
175 |
+
|
176 |
+
@torch.jit.ignore
|
177 |
+
def no_weight_decay(self):
|
178 |
+
return {'pos_embed', 'cls_token'}
|
179 |
+
|
180 |
+
def forward(self, x, register_blk=-1):
|
181 |
+
B = x.shape[0]
|
182 |
+
x = self.patch_embed(x)
|
183 |
+
|
184 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
185 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
186 |
+
|
187 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
188 |
+
x = self.pos_drop(x)
|
189 |
+
|
190 |
+
for i,blk in enumerate(self.blocks):
|
191 |
+
x = blk(x, register_blk==i)
|
192 |
+
x = self.norm(x)
|
193 |
+
|
194 |
+
return x
|
195 |
+
|
196 |
+
@torch.jit.ignore()
|
197 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
198 |
+
_load_weights(self, checkpoint_path, prefix)
|
199 |
+
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
203 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
204 |
+
"""
|
205 |
+
import numpy as np
|
206 |
+
|
207 |
+
def _n2p(w, t=True):
|
208 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
209 |
+
w = w.flatten()
|
210 |
+
if t:
|
211 |
+
if w.ndim == 4:
|
212 |
+
w = w.transpose([3, 2, 0, 1])
|
213 |
+
elif w.ndim == 3:
|
214 |
+
w = w.transpose([2, 0, 1])
|
215 |
+
elif w.ndim == 2:
|
216 |
+
w = w.transpose([1, 0])
|
217 |
+
return torch.from_numpy(w)
|
218 |
+
|
219 |
+
w = np.load(checkpoint_path)
|
220 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
221 |
+
prefix = 'opt/target/'
|
222 |
+
|
223 |
+
if hasattr(model.patch_embed, 'backbone'):
|
224 |
+
# hybrid
|
225 |
+
backbone = model.patch_embed.backbone
|
226 |
+
stem_only = not hasattr(backbone, 'stem')
|
227 |
+
stem = backbone if stem_only else backbone.stem
|
228 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
229 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
230 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
231 |
+
if not stem_only:
|
232 |
+
for i, stage in enumerate(backbone.stages):
|
233 |
+
for j, block in enumerate(stage.blocks):
|
234 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
235 |
+
for r in range(3):
|
236 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
237 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
238 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
239 |
+
if block.downsample is not None:
|
240 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
241 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
242 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
243 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
244 |
+
else:
|
245 |
+
embed_conv_w = adapt_input_conv(
|
246 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
247 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
248 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
249 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
250 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
251 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
252 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
253 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
254 |
+
model.pos_embed.copy_(pos_embed_w)
|
255 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
256 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
257 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
258 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
259 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
260 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
261 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
262 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
263 |
+
for i, block in enumerate(model.blocks.children()):
|
264 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
265 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
266 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
267 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
268 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
269 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
270 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
271 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
272 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
273 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
274 |
+
for r in range(2):
|
275 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
276 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
277 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
278 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
279 |
+
|
280 |
+
|
281 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
282 |
+
# interpolate position embedding
|
283 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
284 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
285 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
286 |
+
# height (== width) for the checkpoint position embedding
|
287 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
288 |
+
# height (== width) for the new position embedding
|
289 |
+
new_size = int(num_patches ** 0.5)
|
290 |
+
|
291 |
+
if orig_size!=new_size:
|
292 |
+
# class_token and dist_token are kept unchanged
|
293 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
294 |
+
# only the position tokens are interpolated
|
295 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
296 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
297 |
+
pos_tokens = torch.nn.functional.interpolate(
|
298 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
299 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
300 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
301 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
302 |
+
|
303 |
+
return new_pos_embed
|
304 |
+
else:
|
305 |
+
return pos_embed_checkpoint
|
finetune/clean_captions_and_tags.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# このスクリプトのライセンスは、Apache License 2.0とします
|
2 |
+
# (c) 2022 Kohya S. @kohya_ss
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import re
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
13 |
+
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
14 |
+
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
|
15 |
+
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
|
16 |
+
|
17 |
+
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
|
18 |
+
PATTERNS_REMOVE_IN_MULTI = [
|
19 |
+
PATTERN_HAIR_LENGTH,
|
20 |
+
PATTERN_HAIR_CUT,
|
21 |
+
re.compile(r', [\w\-]+ eyes, '),
|
22 |
+
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
|
23 |
+
# 複数の髪型定義がある場合は削除する
|
24 |
+
re.compile(
|
25 |
+
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
def clean_tags(image_key, tags):
|
30 |
+
# replace '_' to ' '
|
31 |
+
tags = tags.replace('^_^', '^@@@^')
|
32 |
+
tags = tags.replace('_', ' ')
|
33 |
+
tags = tags.replace('^@@@^', '^_^')
|
34 |
+
|
35 |
+
# remove rating: deepdanbooruのみ
|
36 |
+
tokens = tags.split(", rating")
|
37 |
+
if len(tokens) == 1:
|
38 |
+
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
39 |
+
# print("no rating:")
|
40 |
+
# print(f"{image_key} {tags}")
|
41 |
+
pass
|
42 |
+
else:
|
43 |
+
if len(tokens) > 2:
|
44 |
+
print("multiple ratings:")
|
45 |
+
print(f"{image_key} {tags}")
|
46 |
+
tags = tokens[0]
|
47 |
+
|
48 |
+
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
49 |
+
|
50 |
+
# 複数の人物がいる場合は髪色等のタグを削除する
|
51 |
+
if 'girls' in tags or 'boys' in tags:
|
52 |
+
for pat in PATTERNS_REMOVE_IN_MULTI:
|
53 |
+
found = pat.findall(tags)
|
54 |
+
if len(found) > 1: # 二つ以上、タグがある
|
55 |
+
tags = pat.sub("", tags)
|
56 |
+
|
57 |
+
# 髪の特殊対応
|
58 |
+
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
|
59 |
+
if srch_hair_len:
|
60 |
+
org = srch_hair_len.group()
|
61 |
+
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
|
62 |
+
|
63 |
+
found = PATTERN_HAIR.findall(tags)
|
64 |
+
if len(found) > 1:
|
65 |
+
tags = PATTERN_HAIR.sub("", tags)
|
66 |
+
|
67 |
+
if srch_hair_len:
|
68 |
+
tags = tags.replace(", @@@, ", org) # 戻す
|
69 |
+
|
70 |
+
# white shirtとshirtみたいな重複タグの削除
|
71 |
+
found = PATTERN_WORD.findall(tags)
|
72 |
+
for word in found:
|
73 |
+
if re.search(f", ((\w+) )+{word}, ", tags):
|
74 |
+
tags = tags.replace(f", {word}, ", "")
|
75 |
+
|
76 |
+
tags = tags.replace(", , ", ", ")
|
77 |
+
assert tags.startswith(", ") and tags.endswith(", ")
|
78 |
+
tags = tags[2:-2]
|
79 |
+
return tags
|
80 |
+
|
81 |
+
|
82 |
+
# 上から順に検索、置換される
|
83 |
+
# ('置換元文字列', '置換後文字列')
|
84 |
+
CAPTION_REPLACEMENTS = [
|
85 |
+
('anime anime', 'anime'),
|
86 |
+
('young ', ''),
|
87 |
+
('anime girl', 'girl'),
|
88 |
+
('cartoon female', 'girl'),
|
89 |
+
('cartoon lady', 'girl'),
|
90 |
+
('cartoon character', 'girl'), # a or ~s
|
91 |
+
('cartoon woman', 'girl'),
|
92 |
+
('cartoon women', 'girls'),
|
93 |
+
('cartoon girl', 'girl'),
|
94 |
+
('anime female', 'girl'),
|
95 |
+
('anime lady', 'girl'),
|
96 |
+
('anime character', 'girl'), # a or ~s
|
97 |
+
('anime woman', 'girl'),
|
98 |
+
('anime women', 'girls'),
|
99 |
+
('lady', 'girl'),
|
100 |
+
('female', 'girl'),
|
101 |
+
('woman', 'girl'),
|
102 |
+
('women', 'girls'),
|
103 |
+
('people', 'girls'),
|
104 |
+
('person', 'girl'),
|
105 |
+
('a cartoon figure', 'a figure'),
|
106 |
+
('a cartoon image', 'an image'),
|
107 |
+
('a cartoon picture', 'a picture'),
|
108 |
+
('an anime cartoon image', 'an image'),
|
109 |
+
('a cartoon anime drawing', 'a drawing'),
|
110 |
+
('a cartoon drawing', 'a drawing'),
|
111 |
+
('girl girl', 'girl'),
|
112 |
+
]
|
113 |
+
|
114 |
+
|
115 |
+
def clean_caption(caption):
|
116 |
+
for rf, rt in CAPTION_REPLACEMENTS:
|
117 |
+
replaced = True
|
118 |
+
while replaced:
|
119 |
+
bef = caption
|
120 |
+
caption = caption.replace(rf, rt)
|
121 |
+
replaced = bef != caption
|
122 |
+
return caption
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
if os.path.exists(args.in_json):
|
127 |
+
print(f"loading existing metadata: {args.in_json}")
|
128 |
+
with open(args.in_json, "rt", encoding='utf-8') as f:
|
129 |
+
metadata = json.load(f)
|
130 |
+
else:
|
131 |
+
print("no metadata / メタデータファイルがありません")
|
132 |
+
return
|
133 |
+
|
134 |
+
print("cleaning captions and tags.")
|
135 |
+
image_keys = list(metadata.keys())
|
136 |
+
for image_key in tqdm(image_keys):
|
137 |
+
tags = metadata[image_key].get('tags')
|
138 |
+
if tags is None:
|
139 |
+
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
140 |
+
else:
|
141 |
+
org = tags
|
142 |
+
tags = clean_tags(image_key, tags)
|
143 |
+
metadata[image_key]['tags'] = tags
|
144 |
+
if args.debug and org != tags:
|
145 |
+
print("FROM: " + org)
|
146 |
+
print("TO: " + tags)
|
147 |
+
|
148 |
+
caption = metadata[image_key].get('caption')
|
149 |
+
if caption is None:
|
150 |
+
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
151 |
+
else:
|
152 |
+
org = caption
|
153 |
+
caption = clean_caption(caption)
|
154 |
+
metadata[image_key]['caption'] = caption
|
155 |
+
if args.debug and org != caption:
|
156 |
+
print("FROM: " + org)
|
157 |
+
print("TO: " + caption)
|
158 |
+
|
159 |
+
# metadataを書き出して終わり
|
160 |
+
print(f"writing metadata: {args.out_json}")
|
161 |
+
with open(args.out_json, "wt", encoding='utf-8') as f:
|
162 |
+
json.dump(metadata, f, indent=2)
|
163 |
+
print("done!")
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == '__main__':
|
167 |
+
parser = argparse.ArgumentParser()
|
168 |
+
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
169 |
+
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
170 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
171 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
172 |
+
|
173 |
+
args, unknown = parser.parse_known_args()
|
174 |
+
if len(unknown) == 1:
|
175 |
+
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
176 |
+
print("All captions and tags in the metadata are processed.")
|
177 |
+
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
178 |
+
print("メタデータ内のすべてのキャプションとタグが処理されます。")
|
179 |
+
args.in_json = args.out_json
|
180 |
+
args.out_json = unknown[0]
|
181 |
+
elif len(unknown) > 0:
|
182 |
+
raise ValueError(f"error: unrecognized arguments: {unknown}")
|
183 |
+
|
184 |
+
main(args)
|
finetune/hypernetwork_nai.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NAI compatible
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class HypernetworkModule(torch.nn.Module):
|
7 |
+
def __init__(self, dim, multiplier=1.0):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
linear1 = torch.nn.Linear(dim, dim * 2)
|
11 |
+
linear2 = torch.nn.Linear(dim * 2, dim)
|
12 |
+
linear1.weight.data.normal_(mean=0.0, std=0.01)
|
13 |
+
linear1.bias.data.zero_()
|
14 |
+
linear2.weight.data.normal_(mean=0.0, std=0.01)
|
15 |
+
linear2.bias.data.zero_()
|
16 |
+
linears = [linear1, linear2]
|
17 |
+
|
18 |
+
self.linear = torch.nn.Sequential(*linears)
|
19 |
+
self.multiplier = multiplier
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
return x + self.linear(x) * self.multiplier
|
23 |
+
|
24 |
+
|
25 |
+
class Hypernetwork(torch.nn.Module):
|
26 |
+
enable_sizes = [320, 640, 768, 1280]
|
27 |
+
# return self.modules[Hypernetwork.enable_sizes.index(size)]
|
28 |
+
|
29 |
+
def __init__(self, multiplier=1.0) -> None:
|
30 |
+
super().__init__()
|
31 |
+
self.modules = []
|
32 |
+
for size in Hypernetwork.enable_sizes:
|
33 |
+
self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
|
34 |
+
self.register_module(f"{size}_0", self.modules[-1][0])
|
35 |
+
self.register_module(f"{size}_1", self.modules[-1][1])
|
36 |
+
|
37 |
+
def apply_to_stable_diffusion(self, text_encoder, vae, unet):
|
38 |
+
blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
|
39 |
+
for block in blocks:
|
40 |
+
for subblk in block:
|
41 |
+
if 'SpatialTransformer' in str(type(subblk)):
|
42 |
+
for tf_block in subblk.transformer_blocks:
|
43 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
44 |
+
size = attn.context_dim
|
45 |
+
if size in Hypernetwork.enable_sizes:
|
46 |
+
attn.hypernetwork = self
|
47 |
+
else:
|
48 |
+
attn.hypernetwork = None
|
49 |
+
|
50 |
+
def apply_to_diffusers(self, text_encoder, vae, unet):
|
51 |
+
blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
|
52 |
+
for block in blocks:
|
53 |
+
if hasattr(block, 'attentions'):
|
54 |
+
for subblk in block.attentions:
|
55 |
+
if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
|
56 |
+
for tf_block in subblk.transformer_blocks:
|
57 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
58 |
+
size = attn.to_k.in_features
|
59 |
+
if size in Hypernetwork.enable_sizes:
|
60 |
+
attn.hypernetwork = self
|
61 |
+
else:
|
62 |
+
attn.hypernetwork = None
|
63 |
+
return True # TODO error checking
|
64 |
+
|
65 |
+
def forward(self, x, context):
|
66 |
+
size = context.shape[-1]
|
67 |
+
assert size in Hypernetwork.enable_sizes
|
68 |
+
module = self.modules[Hypernetwork.enable_sizes.index(size)]
|
69 |
+
return module[0].forward(context), module[1].forward(context)
|
70 |
+
|
71 |
+
def load_from_state_dict(self, state_dict):
|
72 |
+
# old ver to new ver
|
73 |
+
changes = {
|
74 |
+
'linear1.bias': 'linear.0.bias',
|
75 |
+
'linear1.weight': 'linear.0.weight',
|
76 |
+
'linear2.bias': 'linear.1.bias',
|
77 |
+
'linear2.weight': 'linear.1.weight',
|
78 |
+
}
|
79 |
+
for key_from, key_to in changes.items():
|
80 |
+
if key_from in state_dict:
|
81 |
+
state_dict[key_to] = state_dict[key_from]
|
82 |
+
del state_dict[key_from]
|
83 |
+
|
84 |
+
for size, sd in state_dict.items():
|
85 |
+
if type(size) == int:
|
86 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
|
87 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
|
88 |
+
return True
|
89 |
+
|
90 |
+
def get_state_dict(self):
|
91 |
+
state_dict = {}
|
92 |
+
for i, size in enumerate(Hypernetwork.enable_sizes):
|
93 |
+
sd0 = self.modules[i][0].state_dict()
|
94 |
+
sd1 = self.modules[i][1].state_dict()
|
95 |
+
state_dict[size] = [sd0, sd1]
|
96 |
+
return state_dict
|
finetune/make_captions.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import random
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode
|
13 |
+
from blip.blip import blip_decoder
|
14 |
+
import library.train_util as train_util
|
15 |
+
|
16 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
|
18 |
+
|
19 |
+
IMAGE_SIZE = 384
|
20 |
+
|
21 |
+
# 正方形でいいのか? という気がするがソースがそうなので
|
22 |
+
IMAGE_TRANSFORM = transforms.Compose([
|
23 |
+
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
24 |
+
transforms.ToTensor(),
|
25 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
26 |
+
])
|
27 |
+
|
28 |
+
# 共通化したいが微妙に処理が異なる……
|
29 |
+
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
30 |
+
def __init__(self, image_paths):
|
31 |
+
self.images = image_paths
|
32 |
+
|
33 |
+
def __len__(self):
|
34 |
+
return len(self.images)
|
35 |
+
|
36 |
+
def __getitem__(self, idx):
|
37 |
+
img_path = self.images[idx]
|
38 |
+
|
39 |
+
try:
|
40 |
+
image = Image.open(img_path).convert("RGB")
|
41 |
+
# convert to tensor temporarily so dataloader will accept it
|
42 |
+
tensor = IMAGE_TRANSFORM(image)
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
45 |
+
return None
|
46 |
+
|
47 |
+
return (tensor, img_path)
|
48 |
+
|
49 |
+
|
50 |
+
def collate_fn_remove_corrupted(batch):
|
51 |
+
"""Collate function that allows to remove corrupted examples in the
|
52 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
53 |
+
The 'None's in the batch are removed.
|
54 |
+
"""
|
55 |
+
# Filter out all the Nones (corrupted examples)
|
56 |
+
batch = list(filter(lambda x: x is not None, batch))
|
57 |
+
return batch
|
58 |
+
|
59 |
+
|
60 |
+
def main(args):
|
61 |
+
# fix the seed for reproducibility
|
62 |
+
seed = args.seed # + utils.get_rank()
|
63 |
+
torch.manual_seed(seed)
|
64 |
+
np.random.seed(seed)
|
65 |
+
random.seed(seed)
|
66 |
+
|
67 |
+
if not os.path.exists("blip"):
|
68 |
+
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
69 |
+
|
70 |
+
cwd = os.getcwd()
|
71 |
+
print('Current Working Directory is: ', cwd)
|
72 |
+
os.chdir('finetune')
|
73 |
+
|
74 |
+
print(f"load images from {args.train_data_dir}")
|
75 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
76 |
+
print(f"found {len(image_paths)} images.")
|
77 |
+
|
78 |
+
print(f"loading BLIP caption: {args.caption_weights}")
|
79 |
+
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
|
80 |
+
model.eval()
|
81 |
+
model = model.to(DEVICE)
|
82 |
+
print("BLIP loaded")
|
83 |
+
|
84 |
+
# captioningする
|
85 |
+
def run_batch(path_imgs):
|
86 |
+
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
87 |
+
|
88 |
+
with torch.no_grad():
|
89 |
+
if args.beam_search:
|
90 |
+
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
91 |
+
max_length=args.max_length, min_length=args.min_length)
|
92 |
+
else:
|
93 |
+
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
94 |
+
|
95 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
96 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
97 |
+
f.write(caption + "\n")
|
98 |
+
if args.debug:
|
99 |
+
print(image_path, caption)
|
100 |
+
|
101 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
102 |
+
if args.max_data_loader_n_workers is not None:
|
103 |
+
dataset = ImageLoadingTransformDataset(image_paths)
|
104 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
105 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
106 |
+
else:
|
107 |
+
data = [[(None, ip)] for ip in image_paths]
|
108 |
+
|
109 |
+
b_imgs = []
|
110 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
111 |
+
for data in data_entry:
|
112 |
+
if data is None:
|
113 |
+
continue
|
114 |
+
|
115 |
+
img_tensor, image_path = data
|
116 |
+
if img_tensor is None:
|
117 |
+
try:
|
118 |
+
raw_image = Image.open(image_path)
|
119 |
+
if raw_image.mode != 'RGB':
|
120 |
+
raw_image = raw_image.convert("RGB")
|
121 |
+
img_tensor = IMAGE_TRANSFORM(raw_image)
|
122 |
+
except Exception as e:
|
123 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
124 |
+
continue
|
125 |
+
|
126 |
+
b_imgs.append((image_path, img_tensor))
|
127 |
+
if len(b_imgs) >= args.batch_size:
|
128 |
+
run_batch(b_imgs)
|
129 |
+
b_imgs.clear()
|
130 |
+
if len(b_imgs) > 0:
|
131 |
+
run_batch(b_imgs)
|
132 |
+
|
133 |
+
print("done!")
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
139 |
+
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
140 |
+
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
141 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
142 |
+
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
143 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
144 |
+
parser.add_argument("--beam_search", action="store_true",
|
145 |
+
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
146 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
147 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
148 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
149 |
+
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
150 |
+
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
151 |
+
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
152 |
+
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
153 |
+
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
154 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
# スペルミスしていたオプションを復元する
|
159 |
+
if args.caption_extention is not None:
|
160 |
+
args.caption_extension = args.caption_extention
|
161 |
+
|
162 |
+
main(args)
|
finetune/make_captions_by_git.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
import torch
|
8 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
9 |
+
from transformers.generation.utils import GenerationMixin
|
10 |
+
|
11 |
+
import library.train_util as train_util
|
12 |
+
|
13 |
+
|
14 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
+
|
16 |
+
PATTERN_REPLACE = [
|
17 |
+
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
18 |
+
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
|
19 |
+
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
|
20 |
+
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
|
21 |
+
re.compile(r'with the words "'),
|
22 |
+
re.compile(r'word \w+ on it'),
|
23 |
+
re.compile(r'that says the word \w+ on it'),
|
24 |
+
re.compile('that says\'the word "( on it)?'),
|
25 |
+
]
|
26 |
+
|
27 |
+
# 誤検知しまくりの with the word xxxx を消す
|
28 |
+
|
29 |
+
|
30 |
+
def remove_words(captions, debug):
|
31 |
+
removed_caps = []
|
32 |
+
for caption in captions:
|
33 |
+
cap = caption
|
34 |
+
for pat in PATTERN_REPLACE:
|
35 |
+
cap = pat.sub("", cap)
|
36 |
+
if debug and cap != caption:
|
37 |
+
print(caption)
|
38 |
+
print(cap)
|
39 |
+
removed_caps.append(cap)
|
40 |
+
return removed_caps
|
41 |
+
|
42 |
+
|
43 |
+
def collate_fn_remove_corrupted(batch):
|
44 |
+
"""Collate function that allows to remove corrupted examples in the
|
45 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
46 |
+
The 'None's in the batch are removed.
|
47 |
+
"""
|
48 |
+
# Filter out all the Nones (corrupted examples)
|
49 |
+
batch = list(filter(lambda x: x is not None, batch))
|
50 |
+
return batch
|
51 |
+
|
52 |
+
|
53 |
+
def main(args):
|
54 |
+
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
55 |
+
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
56 |
+
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
57 |
+
|
58 |
+
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
59 |
+
# ここより上で置き換えようとするとすごく大変
|
60 |
+
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
61 |
+
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
62 |
+
if input_ids.size()[0] != curr_batch_size[0]:
|
63 |
+
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
64 |
+
return input_ids
|
65 |
+
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
66 |
+
|
67 |
+
print(f"load images from {args.train_data_dir}")
|
68 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
69 |
+
print(f"found {len(image_paths)} images.")
|
70 |
+
|
71 |
+
# できればcacheに依存せず明示的にダウンロードしたい
|
72 |
+
print(f"loading GIT: {args.model_id}")
|
73 |
+
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
74 |
+
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
75 |
+
print("GIT loaded")
|
76 |
+
|
77 |
+
# captioningする
|
78 |
+
def run_batch(path_imgs):
|
79 |
+
imgs = [im for _, im in path_imgs]
|
80 |
+
|
81 |
+
curr_batch_size[0] = len(path_imgs)
|
82 |
+
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
83 |
+
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
84 |
+
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
85 |
+
|
86 |
+
if args.remove_words:
|
87 |
+
captions = remove_words(captions, args.debug)
|
88 |
+
|
89 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
90 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
91 |
+
f.write(caption + "\n")
|
92 |
+
if args.debug:
|
93 |
+
print(image_path, caption)
|
94 |
+
|
95 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
96 |
+
if args.max_data_loader_n_workers is not None:
|
97 |
+
dataset = train_util.ImageLoadingDataset(image_paths)
|
98 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
99 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
100 |
+
else:
|
101 |
+
data = [[(None, ip)] for ip in image_paths]
|
102 |
+
|
103 |
+
b_imgs = []
|
104 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
105 |
+
for data in data_entry:
|
106 |
+
if data is None:
|
107 |
+
continue
|
108 |
+
|
109 |
+
image, image_path = data
|
110 |
+
if image is None:
|
111 |
+
try:
|
112 |
+
image = Image.open(image_path)
|
113 |
+
if image.mode != 'RGB':
|
114 |
+
image = image.convert("RGB")
|
115 |
+
except Exception as e:
|
116 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
117 |
+
continue
|
118 |
+
|
119 |
+
b_imgs.append((image_path, image))
|
120 |
+
if len(b_imgs) >= args.batch_size:
|
121 |
+
run_batch(b_imgs)
|
122 |
+
b_imgs.clear()
|
123 |
+
|
124 |
+
if len(b_imgs) > 0:
|
125 |
+
run_batch(b_imgs)
|
126 |
+
|
127 |
+
print("done!")
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
parser = argparse.ArgumentParser()
|
132 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
133 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
134 |
+
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
|
135 |
+
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
|
136 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
137 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
138 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
139 |
+
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
140 |
+
parser.add_argument("--remove_words", action="store_true",
|
141 |
+
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
142 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
143 |
+
|
144 |
+
args = parser.parse_args()
|
145 |
+
main(args)
|
finetune/merge_captions_to_metadata.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List
|
5 |
+
from tqdm import tqdm
|
6 |
+
import library.train_util as train_util
|
7 |
+
|
8 |
+
|
9 |
+
def main(args):
|
10 |
+
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
11 |
+
|
12 |
+
train_data_dir_path = Path(args.train_data_dir)
|
13 |
+
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
14 |
+
print(f"found {len(image_paths)} images.")
|
15 |
+
|
16 |
+
if args.in_json is None and Path(args.out_json).is_file():
|
17 |
+
args.in_json = args.out_json
|
18 |
+
|
19 |
+
if args.in_json is not None:
|
20 |
+
print(f"loading existing metadata: {args.in_json}")
|
21 |
+
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
22 |
+
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
23 |
+
else:
|
24 |
+
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
25 |
+
metadata = {}
|
26 |
+
|
27 |
+
print("merge caption texts to metadata json.")
|
28 |
+
for image_path in tqdm(image_paths):
|
29 |
+
caption_path = image_path.with_suffix(args.caption_extension)
|
30 |
+
caption = caption_path.read_text(encoding='utf-8').strip()
|
31 |
+
|
32 |
+
image_key = str(image_path) if args.full_path else image_path.stem
|
33 |
+
if image_key not in metadata:
|
34 |
+
metadata[image_key] = {}
|
35 |
+
|
36 |
+
metadata[image_key]['caption'] = caption
|
37 |
+
if args.debug:
|
38 |
+
print(image_key, caption)
|
39 |
+
|
40 |
+
# metadataを書き出して終わり
|
41 |
+
print(f"writing metadata: {args.out_json}")
|
42 |
+
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
43 |
+
print("done!")
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
49 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
50 |
+
parser.add_argument("--in_json", type=str,
|
51 |
+
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
52 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
53 |
+
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
54 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
55 |
+
parser.add_argument("--full_path", action="store_true",
|
56 |
+
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
57 |
+
parser.add_argument("--recursive", action="store_true",
|
58 |
+
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
59 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
60 |
+
|
61 |
+
args = parser.parse_args()
|
62 |
+
|
63 |
+
# スペルミスしていたオプションを復元する
|
64 |
+
if args.caption_extention is not None:
|
65 |
+
args.caption_extension = args.caption_extention
|
66 |
+
|
67 |
+
main(args)
|
finetune/merge_dd_tags_to_metadata.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List
|
5 |
+
from tqdm import tqdm
|
6 |
+
import library.train_util as train_util
|
7 |
+
|
8 |
+
|
9 |
+
def main(args):
|
10 |
+
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
11 |
+
|
12 |
+
train_data_dir_path = Path(args.train_data_dir)
|
13 |
+
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
14 |
+
print(f"found {len(image_paths)} images.")
|
15 |
+
|
16 |
+
if args.in_json is None and Path(args.out_json).is_file():
|
17 |
+
args.in_json = args.out_json
|
18 |
+
|
19 |
+
if args.in_json is not None:
|
20 |
+
print(f"loading existing metadata: {args.in_json}")
|
21 |
+
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
22 |
+
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
23 |
+
else:
|
24 |
+
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
25 |
+
metadata = {}
|
26 |
+
|
27 |
+
print("merge tags to metadata json.")
|
28 |
+
for image_path in tqdm(image_paths):
|
29 |
+
tags_path = image_path.with_suffix(args.caption_extension)
|
30 |
+
tags = tags_path.read_text(encoding='utf-8').strip()
|
31 |
+
|
32 |
+
image_key = str(image_path) if args.full_path else image_path.stem
|
33 |
+
if image_key not in metadata:
|
34 |
+
metadata[image_key] = {}
|
35 |
+
|
36 |
+
metadata[image_key]['tags'] = tags
|
37 |
+
if args.debug:
|
38 |
+
print(image_key, tags)
|
39 |
+
|
40 |
+
# metadataを書き出して終わり
|
41 |
+
print(f"writing metadata: {args.out_json}")
|
42 |
+
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
43 |
+
|
44 |
+
print("done!")
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == '__main__':
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
50 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
51 |
+
parser.add_argument("--in_json", type=str,
|
52 |
+
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
53 |
+
parser.add_argument("--full_path", action="store_true",
|
54 |
+
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
55 |
+
parser.add_argument("--recursive", action="store_true",
|
56 |
+
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
57 |
+
parser.add_argument("--caption_extension", type=str, default=".txt",
|
58 |
+
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
|
59 |
+
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
60 |
+
|
61 |
+
args = parser.parse_args()
|
62 |
+
main(args)
|
finetune/prepare_buckets_latents.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import cv2
|
9 |
+
import torch
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
import library.model_util as model_util
|
13 |
+
import library.train_util as train_util
|
14 |
+
|
15 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
IMAGE_TRANSFORMS = transforms.Compose(
|
18 |
+
[
|
19 |
+
transforms.ToTensor(),
|
20 |
+
transforms.Normalize([0.5], [0.5]),
|
21 |
+
]
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def collate_fn_remove_corrupted(batch):
|
26 |
+
"""Collate function that allows to remove corrupted examples in the
|
27 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
28 |
+
The 'None's in the batch are removed.
|
29 |
+
"""
|
30 |
+
# Filter out all the Nones (corrupted examples)
|
31 |
+
batch = list(filter(lambda x: x is not None, batch))
|
32 |
+
return batch
|
33 |
+
|
34 |
+
|
35 |
+
def get_latents(vae, images, weight_dtype):
|
36 |
+
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
37 |
+
img_tensors = torch.stack(img_tensors)
|
38 |
+
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
39 |
+
with torch.no_grad():
|
40 |
+
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
41 |
+
return latents
|
42 |
+
|
43 |
+
|
44 |
+
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
45 |
+
if is_full_path:
|
46 |
+
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
47 |
+
else:
|
48 |
+
base_name = image_key
|
49 |
+
if flip:
|
50 |
+
base_name += '_flip'
|
51 |
+
return os.path.join(data_dir, base_name)
|
52 |
+
|
53 |
+
|
54 |
+
def main(args):
|
55 |
+
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
56 |
+
if args.bucket_reso_steps % 8 > 0:
|
57 |
+
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
58 |
+
|
59 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
60 |
+
print(f"found {len(image_paths)} images.")
|
61 |
+
|
62 |
+
if os.path.exists(args.in_json):
|
63 |
+
print(f"loading existing metadata: {args.in_json}")
|
64 |
+
with open(args.in_json, "rt", encoding='utf-8') as f:
|
65 |
+
metadata = json.load(f)
|
66 |
+
else:
|
67 |
+
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
68 |
+
return
|
69 |
+
|
70 |
+
weight_dtype = torch.float32
|
71 |
+
if args.mixed_precision == "fp16":
|
72 |
+
weight_dtype = torch.float16
|
73 |
+
elif args.mixed_precision == "bf16":
|
74 |
+
weight_dtype = torch.bfloat16
|
75 |
+
|
76 |
+
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
77 |
+
vae.eval()
|
78 |
+
vae.to(DEVICE, dtype=weight_dtype)
|
79 |
+
|
80 |
+
# bucketのサイズを計算する
|
81 |
+
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
82 |
+
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
83 |
+
|
84 |
+
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
|
85 |
+
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
|
86 |
+
if not args.bucket_no_upscale:
|
87 |
+
bucket_manager.make_buckets()
|
88 |
+
else:
|
89 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
90 |
+
|
91 |
+
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
92 |
+
img_ar_errors = []
|
93 |
+
|
94 |
+
def process_batch(is_last):
|
95 |
+
for bucket in bucket_manager.buckets:
|
96 |
+
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
97 |
+
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
98 |
+
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
|
99 |
+
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
100 |
+
|
101 |
+
for (image_key, _), latent in zip(bucket, latents):
|
102 |
+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
103 |
+
np.savez(npz_file_name, latent)
|
104 |
+
|
105 |
+
# flip
|
106 |
+
if args.flip_aug:
|
107 |
+
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
108 |
+
|
109 |
+
for (image_key, _), latent in zip(bucket, latents):
|
110 |
+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
111 |
+
np.savez(npz_file_name, latent)
|
112 |
+
else:
|
113 |
+
# remove existing flipped npz
|
114 |
+
for image_key, _ in bucket:
|
115 |
+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
116 |
+
if os.path.isfile(npz_file_name):
|
117 |
+
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
118 |
+
os.remove(npz_file_name)
|
119 |
+
|
120 |
+
bucket.clear()
|
121 |
+
|
122 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
123 |
+
if args.max_data_loader_n_workers is not None:
|
124 |
+
dataset = train_util.ImageLoadingDataset(image_paths)
|
125 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
126 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
127 |
+
else:
|
128 |
+
data = [[(None, ip)] for ip in image_paths]
|
129 |
+
|
130 |
+
bucket_counts = {}
|
131 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
132 |
+
if data_entry[0] is None:
|
133 |
+
continue
|
134 |
+
|
135 |
+
img_tensor, image_path = data_entry[0]
|
136 |
+
if img_tensor is not None:
|
137 |
+
image = transforms.functional.to_pil_image(img_tensor)
|
138 |
+
else:
|
139 |
+
try:
|
140 |
+
image = Image.open(image_path)
|
141 |
+
if image.mode != 'RGB':
|
142 |
+
image = image.convert("RGB")
|
143 |
+
except Exception as e:
|
144 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
145 |
+
continue
|
146 |
+
|
147 |
+
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
148 |
+
if image_key not in metadata:
|
149 |
+
metadata[image_key] = {}
|
150 |
+
|
151 |
+
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
152 |
+
|
153 |
+
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
154 |
+
img_ar_errors.append(abs(ar_error))
|
155 |
+
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
156 |
+
|
157 |
+
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
158 |
+
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
159 |
+
|
160 |
+
if not args.bucket_no_upscale:
|
161 |
+
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
162 |
+
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
163 |
+
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
164 |
+
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
165 |
+
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
166 |
+
|
167 |
+
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
168 |
+
1], f"internal error resized size is small: {resized_size}, {reso}"
|
169 |
+
|
170 |
+
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
171 |
+
if args.skip_existing:
|
172 |
+
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
173 |
+
if args.flip_aug:
|
174 |
+
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
175 |
+
|
176 |
+
found = True
|
177 |
+
for npz_file in npz_files:
|
178 |
+
if not os.path.exists(npz_file):
|
179 |
+
found = False
|
180 |
+
break
|
181 |
+
|
182 |
+
dat = np.load(npz_file)['arr_0']
|
183 |
+
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
184 |
+
found = False
|
185 |
+
break
|
186 |
+
if found:
|
187 |
+
continue
|
188 |
+
|
189 |
+
# 画像をリサイズしてトリミングする
|
190 |
+
# PILにinter_areaがないのでcv2で……
|
191 |
+
image = np.array(image)
|
192 |
+
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
193 |
+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
194 |
+
|
195 |
+
if resized_size[0] > reso[0]:
|
196 |
+
trim_size = resized_size[0] - reso[0]
|
197 |
+
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
198 |
+
|
199 |
+
if resized_size[1] > reso[1]:
|
200 |
+
trim_size = resized_size[1] - reso[1]
|
201 |
+
image = image[trim_size//2:trim_size//2 + reso[1]]
|
202 |
+
|
203 |
+
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
204 |
+
|
205 |
+
# # debug
|
206 |
+
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
207 |
+
|
208 |
+
# バッチへ追加
|
209 |
+
bucket_manager.add_image(reso, (image_key, image))
|
210 |
+
|
211 |
+
# バッチを推論するか判定して推論する
|
212 |
+
process_batch(False)
|
213 |
+
|
214 |
+
# 残りを処理する
|
215 |
+
process_batch(True)
|
216 |
+
|
217 |
+
bucket_manager.sort()
|
218 |
+
for i, reso in enumerate(bucket_manager.resos):
|
219 |
+
count = bucket_counts.get(reso, 0)
|
220 |
+
if count > 0:
|
221 |
+
print(f"bucket {i} {reso}: {count}")
|
222 |
+
img_ar_errors = np.array(img_ar_errors)
|
223 |
+
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
224 |
+
|
225 |
+
# metadataを書き出して終わり
|
226 |
+
print(f"writing metadata: {args.out_json}")
|
227 |
+
with open(args.out_json, "wt", encoding='utf-8') as f:
|
228 |
+
json.dump(metadata, f, indent=2)
|
229 |
+
print("done!")
|
230 |
+
|
231 |
+
|
232 |
+
if __name__ == '__main__':
|
233 |
+
parser = argparse.ArgumentParser()
|
234 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
235 |
+
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
236 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
237 |
+
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
238 |
+
parser.add_argument("--v2", action='store_true',
|
239 |
+
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
|
240 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
241 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
242 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
243 |
+
parser.add_argument("--max_resolution", type=str, default="512,512",
|
244 |
+
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
245 |
+
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
246 |
+
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
247 |
+
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
248 |
+
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
249 |
+
parser.add_argument("--bucket_no_upscale", action="store_true",
|
250 |
+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
251 |
+
parser.add_argument("--mixed_precision", type=str, default="no",
|
252 |
+
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
253 |
+
parser.add_argument("--full_path", action="store_true",
|
254 |
+
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
255 |
+
parser.add_argument("--flip_aug", action="store_true",
|
256 |
+
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
257 |
+
parser.add_argument("--skip_existing", action="store_true",
|
258 |
+
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
259 |
+
|
260 |
+
args = parser.parse_args()
|
261 |
+
main(args)
|
finetune/tag_images_by_wd14_tagger.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import csv
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
import cv2
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
from tensorflow.keras.models import load_model
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
import torch
|
13 |
+
|
14 |
+
import library.train_util as train_util
|
15 |
+
|
16 |
+
# from wd14 tagger
|
17 |
+
IMAGE_SIZE = 448
|
18 |
+
|
19 |
+
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
20 |
+
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
|
21 |
+
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
22 |
+
SUB_DIR = "variables"
|
23 |
+
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
24 |
+
CSV_FILE = FILES[-1]
|
25 |
+
|
26 |
+
|
27 |
+
def preprocess_image(image):
|
28 |
+
image = np.array(image)
|
29 |
+
image = image[:, :, ::-1] # RGB->BGR
|
30 |
+
|
31 |
+
# pad to square
|
32 |
+
size = max(image.shape[0:2])
|
33 |
+
pad_x = size - image.shape[1]
|
34 |
+
pad_y = size - image.shape[0]
|
35 |
+
pad_l = pad_x // 2
|
36 |
+
pad_t = pad_y // 2
|
37 |
+
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
|
38 |
+
|
39 |
+
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
40 |
+
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
41 |
+
|
42 |
+
image = image.astype(np.float32)
|
43 |
+
return image
|
44 |
+
|
45 |
+
|
46 |
+
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
47 |
+
def __init__(self, image_paths):
|
48 |
+
self.images = image_paths
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.images)
|
52 |
+
|
53 |
+
def __getitem__(self, idx):
|
54 |
+
img_path = self.images[idx]
|
55 |
+
|
56 |
+
try:
|
57 |
+
image = Image.open(img_path).convert("RGB")
|
58 |
+
image = preprocess_image(image)
|
59 |
+
tensor = torch.tensor(image)
|
60 |
+
except Exception as e:
|
61 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
62 |
+
return None
|
63 |
+
|
64 |
+
return (tensor, img_path)
|
65 |
+
|
66 |
+
|
67 |
+
def collate_fn_remove_corrupted(batch):
|
68 |
+
"""Collate function that allows to remove corrupted examples in the
|
69 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
70 |
+
The 'None's in the batch are removed.
|
71 |
+
"""
|
72 |
+
# Filter out all the Nones (corrupted examples)
|
73 |
+
batch = list(filter(lambda x: x is not None, batch))
|
74 |
+
return batch
|
75 |
+
|
76 |
+
|
77 |
+
def main(args):
|
78 |
+
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
79 |
+
# depreacatedの警告が出るけどなくなったらその時
|
80 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
81 |
+
if not os.path.exists(args.model_dir) or args.force_download:
|
82 |
+
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
83 |
+
for file in FILES:
|
84 |
+
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
85 |
+
for file in SUB_DIR_FILES:
|
86 |
+
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
|
87 |
+
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
88 |
+
else:
|
89 |
+
print("using existing wd14 tagger model")
|
90 |
+
|
91 |
+
# 画像を読み込む
|
92 |
+
image_paths = train_util.glob_images(args.train_data_dir)
|
93 |
+
print(f"found {len(image_paths)} images.")
|
94 |
+
|
95 |
+
print("loading model and labels")
|
96 |
+
model = load_model(args.model_dir)
|
97 |
+
|
98 |
+
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
99 |
+
# 依存ライブラリを増やしたくないので自力で読むよ
|
100 |
+
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
|
101 |
+
reader = csv.reader(f)
|
102 |
+
l = [row for row in reader]
|
103 |
+
header = l[0] # tag_id,name,category,count
|
104 |
+
rows = l[1:]
|
105 |
+
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
|
106 |
+
|
107 |
+
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
|
108 |
+
|
109 |
+
# 推論する
|
110 |
+
def run_batch(path_imgs):
|
111 |
+
imgs = np.array([im for _, im in path_imgs])
|
112 |
+
|
113 |
+
probs = model(imgs, training=False)
|
114 |
+
probs = probs.numpy()
|
115 |
+
|
116 |
+
for (image_path, _), prob in zip(path_imgs, probs):
|
117 |
+
# 最初の4つはratingなので無視する
|
118 |
+
# # First 4 labels are actually ratings: pick one with argmax
|
119 |
+
# ratings_names = label_names[:4]
|
120 |
+
# rating_index = ratings_names["probs"].argmax()
|
121 |
+
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
122 |
+
|
123 |
+
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
124 |
+
# Everything else is tags: pick any where prediction confidence > threshold
|
125 |
+
tag_text = ""
|
126 |
+
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
127 |
+
if p >= args.thresh and i < len(tags):
|
128 |
+
tag_text += ", " + tags[i]
|
129 |
+
|
130 |
+
if len(tag_text) > 0:
|
131 |
+
tag_text = tag_text[2:] # 最初の ", " を消す
|
132 |
+
|
133 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
134 |
+
f.write(tag_text + '\n')
|
135 |
+
if args.debug:
|
136 |
+
print(image_path, tag_text)
|
137 |
+
|
138 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
139 |
+
if args.max_data_loader_n_workers is not None:
|
140 |
+
dataset = ImageLoadingPrepDataset(image_paths)
|
141 |
+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
142 |
+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
143 |
+
else:
|
144 |
+
data = [[(None, ip)] for ip in image_paths]
|
145 |
+
|
146 |
+
b_imgs = []
|
147 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
148 |
+
for data in data_entry:
|
149 |
+
if data is None:
|
150 |
+
continue
|
151 |
+
|
152 |
+
image, image_path = data
|
153 |
+
if image is not None:
|
154 |
+
image = image.detach().numpy()
|
155 |
+
else:
|
156 |
+
try:
|
157 |
+
image = Image.open(image_path)
|
158 |
+
if image.mode != 'RGB':
|
159 |
+
image = image.convert("RGB")
|
160 |
+
image = preprocess_image(image)
|
161 |
+
except Exception as e:
|
162 |
+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
163 |
+
continue
|
164 |
+
b_imgs.append((image_path, image))
|
165 |
+
|
166 |
+
if len(b_imgs) >= args.batch_size:
|
167 |
+
run_batch(b_imgs)
|
168 |
+
b_imgs.clear()
|
169 |
+
|
170 |
+
if len(b_imgs) > 0:
|
171 |
+
run_batch(b_imgs)
|
172 |
+
|
173 |
+
print("done!")
|
174 |
+
|
175 |
+
|
176 |
+
if __name__ == '__main__':
|
177 |
+
parser = argparse.ArgumentParser()
|
178 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
179 |
+
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
180 |
+
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
181 |
+
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
182 |
+
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
|
183 |
+
parser.add_argument("--force_download", action='store_true',
|
184 |
+
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
|
185 |
+
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
186 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
187 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
188 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
189 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
190 |
+
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
191 |
+
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
192 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
193 |
+
|
194 |
+
args = parser.parse_args()
|
195 |
+
|
196 |
+
# スペルミスしていたオプションを復元する
|
197 |
+
if args.caption_extention is not None:
|
198 |
+
args.caption_extension = args.caption_extention
|
199 |
+
|
200 |
+
main(args)
|
gen_img_diffusers.py
CHANGED
@@ -47,7 +47,7 @@ VGG(
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
-
from typing import List, Optional, Union
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
@@ -60,7 +60,6 @@ import math
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
63 |
-
from typing import Any, Callable, List, Optional, Union
|
64 |
|
65 |
import diffusers
|
66 |
import numpy as np
|
@@ -81,6 +80,9 @@ from PIL import Image
|
|
81 |
from PIL.PngImagePlugin import PngInfo
|
82 |
|
83 |
import library.model_util as model_util
|
|
|
|
|
|
|
84 |
|
85 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
86 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
@@ -487,6 +489,9 @@ class PipelineLike():
|
|
487 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
488 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
489 |
|
|
|
|
|
|
|
490 |
# Textual Inversion
|
491 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
492 |
self.token_replacements[target_token_id] = rep_token_ids
|
@@ -500,7 +505,11 @@ class PipelineLike():
|
|
500 |
new_tokens.append(token)
|
501 |
return new_tokens
|
502 |
|
|
|
|
|
|
|
503 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
|
|
504 |
def enable_xformers_memory_efficient_attention(self):
|
505 |
r"""
|
506 |
Enable memory efficient attention as implemented in xformers.
|
@@ -581,6 +590,8 @@ class PipelineLike():
|
|
581 |
latents: Optional[torch.FloatTensor] = None,
|
582 |
max_embeddings_multiples: Optional[int] = 3,
|
583 |
output_type: Optional[str] = "pil",
|
|
|
|
|
584 |
# return_dict: bool = True,
|
585 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
586 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
@@ -672,6 +683,9 @@ class PipelineLike():
|
|
672 |
else:
|
673 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
674 |
|
|
|
|
|
|
|
675 |
if strength < 0 or strength > 1:
|
676 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
677 |
|
@@ -752,7 +766,7 @@ class PipelineLike():
|
|
752 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
753 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
754 |
|
755 |
-
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
|
756 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
757 |
clip_guide_images = [clip_guide_images]
|
758 |
|
@@ -765,7 +779,7 @@ class PipelineLike():
|
|
765 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
766 |
if len(image_embeddings_clip) == 1:
|
767 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
768 |
-
|
769 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
770 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
771 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
@@ -774,6 +788,10 @@ class PipelineLike():
|
|
774 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
775 |
if len(image_embeddings_vgg16) == 1:
|
776 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
|
|
|
|
|
|
|
|
777 |
|
778 |
# set timesteps
|
779 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
@@ -781,7 +799,6 @@ class PipelineLike():
|
|
781 |
latents_dtype = text_embeddings.dtype
|
782 |
init_latents_orig = None
|
783 |
mask = None
|
784 |
-
noise = None
|
785 |
|
786 |
if init_image is None:
|
787 |
# get the initial random noise unless the user supplied it
|
@@ -813,6 +830,8 @@ class PipelineLike():
|
|
813 |
if isinstance(init_image[0], PIL.Image.Image):
|
814 |
init_image = [preprocess_image(im) for im in init_image]
|
815 |
init_image = torch.cat(init_image)
|
|
|
|
|
816 |
|
817 |
# mask image to tensor
|
818 |
if mask_image is not None:
|
@@ -823,9 +842,24 @@ class PipelineLike():
|
|
823 |
|
824 |
# encode the init image into latents and scale the latents
|
825 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
826 |
-
|
827 |
-
|
828 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
if len(init_latents) == 1:
|
830 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
831 |
init_latents_orig = init_latents
|
@@ -864,12 +898,21 @@ class PipelineLike():
|
|
864 |
extra_step_kwargs["eta"] = eta
|
865 |
|
866 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
|
|
|
|
|
|
|
|
867 |
for i, t in enumerate(tqdm(timesteps)):
|
868 |
# expand the latents if we are doing classifier free guidance
|
869 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
870 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
871 |
# predict the noise residual
|
872 |
-
|
|
|
|
|
|
|
|
|
873 |
|
874 |
# perform guidance
|
875 |
if do_classifier_free_guidance:
|
@@ -911,8 +954,19 @@ class PipelineLike():
|
|
911 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
912 |
return None
|
913 |
|
|
|
|
|
|
|
914 |
latents = 1 / 0.18215 * latents
|
915 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
916 |
|
917 |
image = (image / 2 + 0.5).clamp(0, 1)
|
918 |
|
@@ -1799,7 +1853,7 @@ def preprocess_mask(mask):
|
|
1799 |
mask = mask.convert("L")
|
1800 |
w, h = mask.size
|
1801 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1802 |
-
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
|
1803 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1804 |
mask = np.tile(mask, (4, 1, 1))
|
1805 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
@@ -1817,6 +1871,35 @@ def preprocess_mask(mask):
|
|
1817 |
# return text_encoder
|
1818 |
|
1819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1820 |
def main(args):
|
1821 |
if args.fp16:
|
1822 |
dtype = torch.float16
|
@@ -1881,10 +1964,7 @@ def main(args):
|
|
1881 |
# tokenizerを読み込む
|
1882 |
print("loading tokenizer")
|
1883 |
if use_stable_diffusion_format:
|
1884 |
-
|
1885 |
-
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1886 |
-
else:
|
1887 |
-
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
1888 |
|
1889 |
# schedulerを用意する
|
1890 |
sched_init_args = {}
|
@@ -1995,11 +2075,13 @@ def main(args):
|
|
1995 |
# networkを組み込む
|
1996 |
if args.network_module:
|
1997 |
networks = []
|
|
|
1998 |
for i, network_module in enumerate(args.network_module):
|
1999 |
print("import network module:", network_module)
|
2000 |
imported_module = importlib.import_module(network_module)
|
2001 |
|
2002 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
|
|
2003 |
|
2004 |
net_kwargs = {}
|
2005 |
if args.network_args and i < len(args.network_args):
|
@@ -2014,7 +2096,7 @@ def main(args):
|
|
2014 |
network_weight = args.network_weights[i]
|
2015 |
print("load network weights from:", network_weight)
|
2016 |
|
2017 |
-
if model_util.is_safetensors(network_weight):
|
2018 |
from safetensors.torch import safe_open
|
2019 |
with safe_open(network_weight, framework="pt") as f:
|
2020 |
metadata = f.metadata()
|
@@ -2037,6 +2119,18 @@ def main(args):
|
|
2037 |
else:
|
2038 |
networks = []
|
2039 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2040 |
if args.opt_channels_last:
|
2041 |
print(f"set optimizing: channels last")
|
2042 |
text_encoder.to(memory_format=torch.channels_last)
|
@@ -2050,9 +2144,14 @@ def main(args):
|
|
2050 |
if vgg16_model is not None:
|
2051 |
vgg16_model.to(memory_format=torch.channels_last)
|
2052 |
|
|
|
|
|
|
|
|
|
2053 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2054 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2055 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
|
|
2056 |
print("pipeline is ready.")
|
2057 |
|
2058 |
if args.diffusers_xformers:
|
@@ -2186,9 +2285,12 @@ def main(args):
|
|
2186 |
|
2187 |
prev_image = None # for VGG16 guided
|
2188 |
if args.guide_image_path is not None:
|
2189 |
-
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
|
2190 |
-
guide_images =
|
2191 |
-
|
|
|
|
|
|
|
2192 |
if len(guide_images) == 0:
|
2193 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2194 |
guide_images = None
|
@@ -2219,33 +2321,46 @@ def main(args):
|
|
2219 |
iter_seed = random.randint(0, 0x7fffffff)
|
2220 |
|
2221 |
# バッチ処理の関数
|
2222 |
-
def process_batch(batch, highres_fix, highres_1st=False):
|
2223 |
batch_size = len(batch)
|
2224 |
|
2225 |
# highres_fixの処理
|
2226 |
if highres_fix and not highres_1st:
|
2227 |
-
# 1st stage
|
2228 |
-
print("process 1st
|
2229 |
batch_1st = []
|
2230 |
-
for
|
2231 |
-
width_1st = int(width * args.highres_fix_scale + .5)
|
2232 |
-
height_1st = int(height * args.highres_fix_scale + .5)
|
2233 |
width_1st = width_1st - width_1st % 32
|
2234 |
height_1st = height_1st - height_1st % 32
|
2235 |
-
|
|
|
|
|
|
|
2236 |
images_1st = process_batch(batch_1st, True, True)
|
2237 |
|
2238 |
# 2nd stageのバッチを作成して以下処理する
|
2239 |
-
print("process 2nd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2240 |
batch_2nd = []
|
2241 |
-
for i, (
|
2242 |
-
|
2243 |
-
|
2244 |
-
|
|
|
2245 |
batch = batch_2nd
|
2246 |
|
2247 |
-
|
2248 |
-
|
|
|
2249 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2250 |
|
2251 |
prompts = []
|
@@ -2278,7 +2393,7 @@ def main(args):
|
|
2278 |
all_images_are_same = True
|
2279 |
all_masks_are_same = True
|
2280 |
all_guide_images_are_same = True
|
2281 |
-
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
2282 |
prompts.append(prompt)
|
2283 |
negative_prompts.append(negative_prompt)
|
2284 |
seeds.append(seed)
|
@@ -2295,9 +2410,13 @@ def main(args):
|
|
2295 |
all_masks_are_same = mask_images[-2] is mask_image
|
2296 |
|
2297 |
if guide_image is not None:
|
2298 |
-
|
2299 |
-
|
2300 |
-
all_guide_images_are_same =
|
|
|
|
|
|
|
|
|
2301 |
|
2302 |
# make start code
|
2303 |
torch.manual_seed(seed)
|
@@ -2320,10 +2439,24 @@ def main(args):
|
|
2320 |
if guide_images is not None and all_guide_images_are_same:
|
2321 |
guide_images = guide_images[0]
|
2322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2323 |
# generate
|
|
|
|
|
|
|
|
|
2324 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2325 |
-
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
2326 |
-
|
|
|
|
|
2327 |
return images
|
2328 |
|
2329 |
# save image
|
@@ -2398,6 +2531,7 @@ def main(args):
|
|
2398 |
strength = 0.8 if args.strength is None else args.strength
|
2399 |
negative_prompt = ""
|
2400 |
clip_prompt = None
|
|
|
2401 |
|
2402 |
prompt_args = prompt.strip().split(' --')
|
2403 |
prompt = prompt_args[0]
|
@@ -2461,6 +2595,15 @@ def main(args):
|
|
2461 |
clip_prompt = m.group(1)
|
2462 |
print(f"clip prompt: {clip_prompt}")
|
2463 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2464 |
except ValueError as ex:
|
2465 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2466 |
print(ex)
|
@@ -2498,7 +2641,12 @@ def main(args):
|
|
2498 |
mask_image = mask_images[global_step % len(mask_images)]
|
2499 |
|
2500 |
if guide_images is not None:
|
2501 |
-
|
|
|
|
|
|
|
|
|
|
|
2502 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2503 |
if prev_image is None:
|
2504 |
print("Generate 1st image without guide image.")
|
@@ -2506,10 +2654,9 @@ def main(args):
|
|
2506 |
print("Use previous image as guide image.")
|
2507 |
guide_image = prev_image
|
2508 |
|
2509 |
-
|
2510 |
-
|
2511 |
-
|
2512 |
-
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
2513 |
process_batch(batch_data, highres_fix)
|
2514 |
batch_data.clear()
|
2515 |
|
@@ -2553,6 +2700,8 @@ if __name__ == '__main__':
|
|
2553 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2554 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2555 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
|
|
|
|
2556 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2557 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2558 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
@@ -2564,6 +2713,8 @@ if __name__ == '__main__':
|
|
2564 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2565 |
parser.add_argument("--vae", type=str, default=None,
|
2566 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
|
|
|
|
2567 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2568 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2569 |
parser.add_argument("--seed", type=int, default=None,
|
@@ -2578,12 +2729,15 @@ if __name__ == '__main__':
|
|
2578 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2579 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2580 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2581 |
-
help='
|
2582 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2583 |
-
help='
|
2584 |
-
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
|
|
2585 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2586 |
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
|
|
|
|
2587 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2588 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2589 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
@@ -2597,15 +2751,26 @@ if __name__ == '__main__':
|
|
2597 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2598 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2599 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2600 |
-
parser.add_argument("--guide_image_path", type=str, default=None,
|
|
|
2601 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2602 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2603 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2604 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2605 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2606 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
|
|
|
|
2607 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2608 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2610 |
args = parser.parse_args()
|
2611 |
main(args)
|
|
|
47 |
"""
|
48 |
|
49 |
import json
|
50 |
+
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
51 |
import glob
|
52 |
import importlib
|
53 |
import inspect
|
|
|
60 |
import os
|
61 |
import random
|
62 |
import re
|
|
|
63 |
|
64 |
import diffusers
|
65 |
import numpy as np
|
|
|
80 |
from PIL.PngImagePlugin import PngInfo
|
81 |
|
82 |
import library.model_util as model_util
|
83 |
+
import library.train_util as train_util
|
84 |
+
import tools.original_control_net as original_control_net
|
85 |
+
from tools.original_control_net import ControlNetInfo
|
86 |
|
87 |
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
88 |
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
|
489 |
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
490 |
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
491 |
|
492 |
+
# ControlNet
|
493 |
+
self.control_nets: List[ControlNetInfo] = []
|
494 |
+
|
495 |
# Textual Inversion
|
496 |
def add_token_replacement(self, target_token_id, rep_token_ids):
|
497 |
self.token_replacements[target_token_id] = rep_token_ids
|
|
|
505 |
new_tokens.append(token)
|
506 |
return new_tokens
|
507 |
|
508 |
+
def set_control_nets(self, ctrl_nets):
|
509 |
+
self.control_nets = ctrl_nets
|
510 |
+
|
511 |
# region xformersとか使う部分:独自に書き換えるので関係なし
|
512 |
+
|
513 |
def enable_xformers_memory_efficient_attention(self):
|
514 |
r"""
|
515 |
Enable memory efficient attention as implemented in xformers.
|
|
|
590 |
latents: Optional[torch.FloatTensor] = None,
|
591 |
max_embeddings_multiples: Optional[int] = 3,
|
592 |
output_type: Optional[str] = "pil",
|
593 |
+
vae_batch_size: float = None,
|
594 |
+
return_latents: bool = False,
|
595 |
# return_dict: bool = True,
|
596 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
597 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
|
|
683 |
else:
|
684 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
685 |
|
686 |
+
vae_batch_size = batch_size if vae_batch_size is None else (
|
687 |
+
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
|
688 |
+
|
689 |
if strength < 0 or strength > 1:
|
690 |
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
691 |
|
|
|
766 |
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
767 |
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
768 |
|
769 |
+
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
|
770 |
if isinstance(clip_guide_images, PIL.Image.Image):
|
771 |
clip_guide_images = [clip_guide_images]
|
772 |
|
|
|
779 |
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
780 |
if len(image_embeddings_clip) == 1:
|
781 |
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
782 |
+
elif self.vgg16_guidance_scale > 0:
|
783 |
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
784 |
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
785 |
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
|
|
788 |
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
789 |
if len(image_embeddings_vgg16) == 1:
|
790 |
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
791 |
+
else:
|
792 |
+
# ControlNetのhintにguide imageを流用する
|
793 |
+
# 前処理はControlNet側で行う
|
794 |
+
pass
|
795 |
|
796 |
# set timesteps
|
797 |
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
|
|
799 |
latents_dtype = text_embeddings.dtype
|
800 |
init_latents_orig = None
|
801 |
mask = None
|
|
|
802 |
|
803 |
if init_image is None:
|
804 |
# get the initial random noise unless the user supplied it
|
|
|
830 |
if isinstance(init_image[0], PIL.Image.Image):
|
831 |
init_image = [preprocess_image(im) for im in init_image]
|
832 |
init_image = torch.cat(init_image)
|
833 |
+
if isinstance(init_image, list):
|
834 |
+
init_image = torch.stack(init_image)
|
835 |
|
836 |
# mask image to tensor
|
837 |
if mask_image is not None:
|
|
|
842 |
|
843 |
# encode the init image into latents and scale the latents
|
844 |
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
845 |
+
if init_image.size()[2:] == (height // 8, width // 8):
|
846 |
+
init_latents = init_image
|
847 |
+
else:
|
848 |
+
if vae_batch_size >= batch_size:
|
849 |
+
init_latent_dist = self.vae.encode(init_image).latent_dist
|
850 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
851 |
+
else:
|
852 |
+
if torch.cuda.is_available():
|
853 |
+
torch.cuda.empty_cache()
|
854 |
+
init_latents = []
|
855 |
+
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
856 |
+
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
|
857 |
+
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
|
858 |
+
init_latents.append(init_latent_dist.sample(generator=generator))
|
859 |
+
init_latents = torch.cat(init_latents)
|
860 |
+
|
861 |
+
init_latents = 0.18215 * init_latents
|
862 |
+
|
863 |
if len(init_latents) == 1:
|
864 |
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
865 |
init_latents_orig = init_latents
|
|
|
898 |
extra_step_kwargs["eta"] = eta
|
899 |
|
900 |
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
901 |
+
|
902 |
+
if self.control_nets:
|
903 |
+
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
904 |
+
|
905 |
for i, t in enumerate(tqdm(timesteps)):
|
906 |
# expand the latents if we are doing classifier free guidance
|
907 |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
908 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
909 |
+
|
910 |
# predict the noise residual
|
911 |
+
if self.control_nets:
|
912 |
+
noise_pred = original_control_net.call_unet_and_control_net(
|
913 |
+
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
|
914 |
+
else:
|
915 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
916 |
|
917 |
# perform guidance
|
918 |
if do_classifier_free_guidance:
|
|
|
954 |
if is_cancelled_callback is not None and is_cancelled_callback():
|
955 |
return None
|
956 |
|
957 |
+
if return_latents:
|
958 |
+
return (latents, False)
|
959 |
+
|
960 |
latents = 1 / 0.18215 * latents
|
961 |
+
if vae_batch_size >= batch_size:
|
962 |
+
image = self.vae.decode(latents).sample
|
963 |
+
else:
|
964 |
+
if torch.cuda.is_available():
|
965 |
+
torch.cuda.empty_cache()
|
966 |
+
images = []
|
967 |
+
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
968 |
+
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
|
969 |
+
image = torch.cat(images)
|
970 |
|
971 |
image = (image / 2 + 0.5).clamp(0, 1)
|
972 |
|
|
|
1853 |
mask = mask.convert("L")
|
1854 |
w, h = mask.size
|
1855 |
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
1856 |
+
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
|
1857 |
mask = np.array(mask).astype(np.float32) / 255.0
|
1858 |
mask = np.tile(mask, (4, 1, 1))
|
1859 |
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
|
|
1871 |
# return text_encoder
|
1872 |
|
1873 |
|
1874 |
+
class BatchDataBase(NamedTuple):
|
1875 |
+
# バッチ分割が必要ないデータ
|
1876 |
+
step: int
|
1877 |
+
prompt: str
|
1878 |
+
negative_prompt: str
|
1879 |
+
seed: int
|
1880 |
+
init_image: Any
|
1881 |
+
mask_image: Any
|
1882 |
+
clip_prompt: str
|
1883 |
+
guide_image: Any
|
1884 |
+
|
1885 |
+
|
1886 |
+
class BatchDataExt(NamedTuple):
|
1887 |
+
# バッチ分割が必要なデータ
|
1888 |
+
width: int
|
1889 |
+
height: int
|
1890 |
+
steps: int
|
1891 |
+
scale: float
|
1892 |
+
negative_scale: float
|
1893 |
+
strength: float
|
1894 |
+
network_muls: Tuple[float]
|
1895 |
+
|
1896 |
+
|
1897 |
+
class BatchData(NamedTuple):
|
1898 |
+
return_latents: bool
|
1899 |
+
base: BatchDataBase
|
1900 |
+
ext: BatchDataExt
|
1901 |
+
|
1902 |
+
|
1903 |
def main(args):
|
1904 |
if args.fp16:
|
1905 |
dtype = torch.float16
|
|
|
1964 |
# tokenizerを読み込む
|
1965 |
print("loading tokenizer")
|
1966 |
if use_stable_diffusion_format:
|
1967 |
+
tokenizer = train_util.load_tokenizer(args)
|
|
|
|
|
|
|
1968 |
|
1969 |
# schedulerを用意する
|
1970 |
sched_init_args = {}
|
|
|
2075 |
# networkを組み込む
|
2076 |
if args.network_module:
|
2077 |
networks = []
|
2078 |
+
network_default_muls = []
|
2079 |
for i, network_module in enumerate(args.network_module):
|
2080 |
print("import network module:", network_module)
|
2081 |
imported_module = importlib.import_module(network_module)
|
2082 |
|
2083 |
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
2084 |
+
network_default_muls.append(network_mul)
|
2085 |
|
2086 |
net_kwargs = {}
|
2087 |
if args.network_args and i < len(args.network_args):
|
|
|
2096 |
network_weight = args.network_weights[i]
|
2097 |
print("load network weights from:", network_weight)
|
2098 |
|
2099 |
+
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
2100 |
from safetensors.torch import safe_open
|
2101 |
with safe_open(network_weight, framework="pt") as f:
|
2102 |
metadata = f.metadata()
|
|
|
2119 |
else:
|
2120 |
networks = []
|
2121 |
|
2122 |
+
# ControlNetの処理
|
2123 |
+
control_nets: List[ControlNetInfo] = []
|
2124 |
+
if args.control_net_models:
|
2125 |
+
for i, model in enumerate(args.control_net_models):
|
2126 |
+
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
2127 |
+
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
2128 |
+
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
2129 |
+
|
2130 |
+
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
2131 |
+
prep = original_control_net.load_preprocess(prep_type)
|
2132 |
+
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
2133 |
+
|
2134 |
if args.opt_channels_last:
|
2135 |
print(f"set optimizing: channels last")
|
2136 |
text_encoder.to(memory_format=torch.channels_last)
|
|
|
2144 |
if vgg16_model is not None:
|
2145 |
vgg16_model.to(memory_format=torch.channels_last)
|
2146 |
|
2147 |
+
for cn in control_nets:
|
2148 |
+
cn.unet.to(memory_format=torch.channels_last)
|
2149 |
+
cn.net.to(memory_format=torch.channels_last)
|
2150 |
+
|
2151 |
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
2152 |
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
2153 |
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
2154 |
+
pipe.set_control_nets(control_nets)
|
2155 |
print("pipeline is ready.")
|
2156 |
|
2157 |
if args.diffusers_xformers:
|
|
|
2285 |
|
2286 |
prev_image = None # for VGG16 guided
|
2287 |
if args.guide_image_path is not None:
|
2288 |
+
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
2289 |
+
guide_images = []
|
2290 |
+
for p in args.guide_image_path:
|
2291 |
+
guide_images.extend(load_images(p))
|
2292 |
+
|
2293 |
+
print(f"loaded {len(guide_images)} guide images for guidance")
|
2294 |
if len(guide_images) == 0:
|
2295 |
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
2296 |
guide_images = None
|
|
|
2321 |
iter_seed = random.randint(0, 0x7fffffff)
|
2322 |
|
2323 |
# バッチ処理の関数
|
2324 |
+
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
2325 |
batch_size = len(batch)
|
2326 |
|
2327 |
# highres_fixの処理
|
2328 |
if highres_fix and not highres_1st:
|
2329 |
+
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
2330 |
+
print("process 1st stage")
|
2331 |
batch_1st = []
|
2332 |
+
for _, base, ext in batch:
|
2333 |
+
width_1st = int(ext.width * args.highres_fix_scale + .5)
|
2334 |
+
height_1st = int(ext.height * args.highres_fix_scale + .5)
|
2335 |
width_1st = width_1st - width_1st % 32
|
2336 |
height_1st = height_1st - height_1st % 32
|
2337 |
+
|
2338 |
+
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
2339 |
+
ext.negative_scale, ext.strength, ext.network_muls)
|
2340 |
+
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
2341 |
images_1st = process_batch(batch_1st, True, True)
|
2342 |
|
2343 |
# 2nd stageのバッチを作成して以下処理する
|
2344 |
+
print("process 2nd stage")
|
2345 |
+
if args.highres_fix_latents_upscaling:
|
2346 |
+
org_dtype = images_1st.dtype
|
2347 |
+
if images_1st.dtype == torch.bfloat16:
|
2348 |
+
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
2349 |
+
images_1st = torch.nn.functional.interpolate(
|
2350 |
+
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
|
2351 |
+
images_1st = images_1st.to(org_dtype)
|
2352 |
+
|
2353 |
batch_2nd = []
|
2354 |
+
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
2355 |
+
if not args.highres_fix_latents_upscaling:
|
2356 |
+
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
2357 |
+
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
2358 |
+
batch_2nd.append(bd_2nd)
|
2359 |
batch = batch_2nd
|
2360 |
|
2361 |
+
# このバッチの情報を取り出す
|
2362 |
+
return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
2363 |
+
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
2364 |
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
2365 |
|
2366 |
prompts = []
|
|
|
2393 |
all_images_are_same = True
|
2394 |
all_masks_are_same = True
|
2395 |
all_guide_images_are_same = True
|
2396 |
+
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
2397 |
prompts.append(prompt)
|
2398 |
negative_prompts.append(negative_prompt)
|
2399 |
seeds.append(seed)
|
|
|
2410 |
all_masks_are_same = mask_images[-2] is mask_image
|
2411 |
|
2412 |
if guide_image is not None:
|
2413 |
+
if type(guide_image) is list:
|
2414 |
+
guide_images.extend(guide_image)
|
2415 |
+
all_guide_images_are_same = False
|
2416 |
+
else:
|
2417 |
+
guide_images.append(guide_image)
|
2418 |
+
if i > 0 and all_guide_images_are_same:
|
2419 |
+
all_guide_images_are_same = guide_images[-2] is guide_image
|
2420 |
|
2421 |
# make start code
|
2422 |
torch.manual_seed(seed)
|
|
|
2439 |
if guide_images is not None and all_guide_images_are_same:
|
2440 |
guide_images = guide_images[0]
|
2441 |
|
2442 |
+
# ControlNet使用時はguide imageをリサイズする
|
2443 |
+
if control_nets:
|
2444 |
+
# TODO resampleのメソッド
|
2445 |
+
guide_images = guide_images if type(guide_images) == list else [guide_images]
|
2446 |
+
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
|
2447 |
+
if len(guide_images) == 1:
|
2448 |
+
guide_images = guide_images[0]
|
2449 |
+
|
2450 |
# generate
|
2451 |
+
if networks:
|
2452 |
+
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
2453 |
+
n.set_multiplier(m)
|
2454 |
+
|
2455 |
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
2456 |
+
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
2457 |
+
vae_batch_size=args.vae_batch_size, return_latents=return_latents,
|
2458 |
+
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
2459 |
+
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
2460 |
return images
|
2461 |
|
2462 |
# save image
|
|
|
2531 |
strength = 0.8 if args.strength is None else args.strength
|
2532 |
negative_prompt = ""
|
2533 |
clip_prompt = None
|
2534 |
+
network_muls = None
|
2535 |
|
2536 |
prompt_args = prompt.strip().split(' --')
|
2537 |
prompt = prompt_args[0]
|
|
|
2595 |
clip_prompt = m.group(1)
|
2596 |
print(f"clip prompt: {clip_prompt}")
|
2597 |
continue
|
2598 |
+
|
2599 |
+
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
|
2600 |
+
if m: # network multiplies
|
2601 |
+
network_muls = [float(v) for v in m.group(1).split(",")]
|
2602 |
+
while len(network_muls) < len(networks):
|
2603 |
+
network_muls.append(network_muls[-1])
|
2604 |
+
print(f"network mul: {network_muls}")
|
2605 |
+
continue
|
2606 |
+
|
2607 |
except ValueError as ex:
|
2608 |
print(f"Exception in parsing / 解析エラー: {parg}")
|
2609 |
print(ex)
|
|
|
2641 |
mask_image = mask_images[global_step % len(mask_images)]
|
2642 |
|
2643 |
if guide_images is not None:
|
2644 |
+
if control_nets: # 複数件の場合あり
|
2645 |
+
c = len(control_nets)
|
2646 |
+
p = global_step % (len(guide_images) // c)
|
2647 |
+
guide_image = guide_images[p * c:p * c + c]
|
2648 |
+
else:
|
2649 |
+
guide_image = guide_images[global_step % len(guide_images)]
|
2650 |
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
2651 |
if prev_image is None:
|
2652 |
print("Generate 1st image without guide image.")
|
|
|
2654 |
print("Use previous image as guide image.")
|
2655 |
guide_image = prev_image
|
2656 |
|
2657 |
+
b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
2658 |
+
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
2659 |
+
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
|
|
2660 |
process_batch(batch_data, highres_fix)
|
2661 |
batch_data.clear()
|
2662 |
|
|
|
2700 |
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
2701 |
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
2702 |
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
2703 |
+
parser.add_argument("--vae_batch_size", type=float, default=None,
|
2704 |
+
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
|
2705 |
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
2706 |
parser.add_argument('--sampler', type=str, default='ddim',
|
2707 |
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
|
|
2713 |
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
2714 |
parser.add_argument("--vae", type=str, default=None,
|
2715 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
2716 |
+
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
2717 |
+
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
2718 |
# parser.add_argument("--replace_clip_l14_336", action='store_true',
|
2719 |
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
2720 |
parser.add_argument("--seed", type=int, default=None,
|
|
|
2729 |
parser.add_argument("--opt_channels_last", action='store_true',
|
2730 |
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
2731 |
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
2732 |
+
help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
|
2733 |
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
2734 |
+
help='additional network weights to load / 追加ネットワークの重み')
|
2735 |
+
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
2736 |
+
help='additional network multiplier / 追加ネットワークの効果の倍率')
|
2737 |
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
2738 |
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
2739 |
+
parser.add_argument("--network_show_meta", action='store_true',
|
2740 |
+
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
|
2741 |
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
2742 |
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
2743 |
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
|
|
2751 |
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
2752 |
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
2753 |
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
2754 |
+
parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
|
2755 |
+
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
2756 |
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
2757 |
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
2758 |
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
2759 |
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
2760 |
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
2761 |
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
2762 |
+
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
|
2763 |
+
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
|
2764 |
parser.add_argument("--negative_scale", type=float, default=None,
|
2765 |
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
2766 |
|
2767 |
+
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
|
2768 |
+
help='ControlNet models to use / 使用するControlNetのモデル名')
|
2769 |
+
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
|
2770 |
+
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
|
2771 |
+
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
|
2772 |
+
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
|
2773 |
+
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
|
2774 |
+
|
2775 |
args = parser.parse_args()
|
2776 |
main(args)
|
library/config_util.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
from textwrap import dedent, indent
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
# from toolz import curry
|
11 |
+
from typing import (
|
12 |
+
List,
|
13 |
+
Optional,
|
14 |
+
Sequence,
|
15 |
+
Tuple,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
+
|
19 |
+
import toml
|
20 |
+
import voluptuous
|
21 |
+
from voluptuous import (
|
22 |
+
Any,
|
23 |
+
ExactSequence,
|
24 |
+
MultipleInvalid,
|
25 |
+
Object,
|
26 |
+
Required,
|
27 |
+
Schema,
|
28 |
+
)
|
29 |
+
from transformers import CLIPTokenizer
|
30 |
+
|
31 |
+
from . import train_util
|
32 |
+
from .train_util import (
|
33 |
+
DreamBoothSubset,
|
34 |
+
FineTuningSubset,
|
35 |
+
DreamBoothDataset,
|
36 |
+
FineTuningDataset,
|
37 |
+
DatasetGroup,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
42 |
+
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
43 |
+
|
44 |
+
# TODO: inherit Params class in Subset, Dataset
|
45 |
+
|
46 |
+
@dataclass
|
47 |
+
class BaseSubsetParams:
|
48 |
+
image_dir: Optional[str] = None
|
49 |
+
num_repeats: int = 1
|
50 |
+
shuffle_caption: bool = False
|
51 |
+
keep_tokens: int = 0
|
52 |
+
color_aug: bool = False
|
53 |
+
flip_aug: bool = False
|
54 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
55 |
+
random_crop: bool = False
|
56 |
+
caption_dropout_rate: float = 0.0
|
57 |
+
caption_dropout_every_n_epochs: int = 0
|
58 |
+
caption_tag_dropout_rate: float = 0.0
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
62 |
+
is_reg: bool = False
|
63 |
+
class_tokens: Optional[str] = None
|
64 |
+
caption_extension: str = ".caption"
|
65 |
+
|
66 |
+
@dataclass
|
67 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
68 |
+
metadata_file: Optional[str] = None
|
69 |
+
|
70 |
+
@dataclass
|
71 |
+
class BaseDatasetParams:
|
72 |
+
tokenizer: CLIPTokenizer = None
|
73 |
+
max_token_length: int = None
|
74 |
+
resolution: Optional[Tuple[int, int]] = None
|
75 |
+
debug_dataset: bool = False
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
79 |
+
batch_size: int = 1
|
80 |
+
enable_bucket: bool = False
|
81 |
+
min_bucket_reso: int = 256
|
82 |
+
max_bucket_reso: int = 1024
|
83 |
+
bucket_reso_steps: int = 64
|
84 |
+
bucket_no_upscale: bool = False
|
85 |
+
prior_loss_weight: float = 1.0
|
86 |
+
|
87 |
+
@dataclass
|
88 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
89 |
+
batch_size: int = 1
|
90 |
+
enable_bucket: bool = False
|
91 |
+
min_bucket_reso: int = 256
|
92 |
+
max_bucket_reso: int = 1024
|
93 |
+
bucket_reso_steps: int = 64
|
94 |
+
bucket_no_upscale: bool = False
|
95 |
+
|
96 |
+
@dataclass
|
97 |
+
class SubsetBlueprint:
|
98 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
99 |
+
|
100 |
+
@dataclass
|
101 |
+
class DatasetBlueprint:
|
102 |
+
is_dreambooth: bool
|
103 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
104 |
+
subsets: Sequence[SubsetBlueprint]
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class DatasetGroupBlueprint:
|
108 |
+
datasets: Sequence[DatasetBlueprint]
|
109 |
+
@dataclass
|
110 |
+
class Blueprint:
|
111 |
+
dataset_group: DatasetGroupBlueprint
|
112 |
+
|
113 |
+
|
114 |
+
class ConfigSanitizer:
|
115 |
+
# @curry
|
116 |
+
@staticmethod
|
117 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
118 |
+
Schema(ExactSequence([klass, klass]))(value)
|
119 |
+
return tuple(value)
|
120 |
+
|
121 |
+
# @curry
|
122 |
+
@staticmethod
|
123 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
124 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
125 |
+
try:
|
126 |
+
Schema(klass)(value)
|
127 |
+
return (value, value)
|
128 |
+
except:
|
129 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
130 |
+
|
131 |
+
# subset schema
|
132 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
133 |
+
"color_aug": bool,
|
134 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
135 |
+
"flip_aug": bool,
|
136 |
+
"num_repeats": int,
|
137 |
+
"random_crop": bool,
|
138 |
+
"shuffle_caption": bool,
|
139 |
+
"keep_tokens": int,
|
140 |
+
}
|
141 |
+
# DO means DropOut
|
142 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
143 |
+
"caption_dropout_every_n_epochs": int,
|
144 |
+
"caption_dropout_rate": Any(float, int),
|
145 |
+
"caption_tag_dropout_rate": Any(float, int),
|
146 |
+
}
|
147 |
+
# DB means DreamBooth
|
148 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
149 |
+
"caption_extension": str,
|
150 |
+
"class_tokens": str,
|
151 |
+
}
|
152 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
153 |
+
Required("image_dir"): str,
|
154 |
+
"is_reg": bool,
|
155 |
+
}
|
156 |
+
# FT means FineTuning
|
157 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
158 |
+
Required("metadata_file"): str,
|
159 |
+
"image_dir": str,
|
160 |
+
}
|
161 |
+
|
162 |
+
# datasets schema
|
163 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
164 |
+
"batch_size": int,
|
165 |
+
"bucket_no_upscale": bool,
|
166 |
+
"bucket_reso_steps": int,
|
167 |
+
"enable_bucket": bool,
|
168 |
+
"max_bucket_reso": int,
|
169 |
+
"min_bucket_reso": int,
|
170 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
171 |
+
}
|
172 |
+
|
173 |
+
# options handled by argparse but not handled by user config
|
174 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
175 |
+
"debug_dataset": bool,
|
176 |
+
"max_token_length": Any(None, int),
|
177 |
+
"prior_loss_weight": Any(float, int),
|
178 |
+
}
|
179 |
+
# for handling default None value of argparse
|
180 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
181 |
+
"face_crop_aug_range",
|
182 |
+
"resolution",
|
183 |
+
]
|
184 |
+
# prepare map because option name may differ among argparse and user config
|
185 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
186 |
+
"train_batch_size": "batch_size",
|
187 |
+
"dataset_repeats": "num_repeats",
|
188 |
+
}
|
189 |
+
|
190 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
|
191 |
+
assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
192 |
+
|
193 |
+
self.db_subset_schema = self.__merge_dict(
|
194 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
195 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
196 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
197 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
198 |
+
)
|
199 |
+
|
200 |
+
self.ft_subset_schema = self.__merge_dict(
|
201 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
202 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
203 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
204 |
+
)
|
205 |
+
|
206 |
+
self.db_dataset_schema = self.__merge_dict(
|
207 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
208 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
209 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
210 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
211 |
+
{"subsets": [self.db_subset_schema]},
|
212 |
+
)
|
213 |
+
|
214 |
+
self.ft_dataset_schema = self.__merge_dict(
|
215 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
216 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
217 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
218 |
+
{"subsets": [self.ft_subset_schema]},
|
219 |
+
)
|
220 |
+
|
221 |
+
if support_dreambooth and support_finetuning:
|
222 |
+
def validate_flex_dataset(dataset_config: dict):
|
223 |
+
subsets_config = dataset_config.get("subsets", [])
|
224 |
+
|
225 |
+
# check dataset meets FT style
|
226 |
+
# NOTE: all FT subsets should have "metadata_file"
|
227 |
+
if all(["metadata_file" in subset for subset in subsets_config]):
|
228 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
229 |
+
# check dataset meets DB style
|
230 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
231 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
232 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
233 |
+
else:
|
234 |
+
raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
|
235 |
+
|
236 |
+
self.dataset_schema = validate_flex_dataset
|
237 |
+
elif support_dreambooth:
|
238 |
+
self.dataset_schema = self.db_dataset_schema
|
239 |
+
else:
|
240 |
+
self.dataset_schema = self.ft_dataset_schema
|
241 |
+
|
242 |
+
self.general_schema = self.__merge_dict(
|
243 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
244 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
245 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
246 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
247 |
+
)
|
248 |
+
|
249 |
+
self.user_config_validator = Schema({
|
250 |
+
"general": self.general_schema,
|
251 |
+
"datasets": [self.dataset_schema],
|
252 |
+
})
|
253 |
+
|
254 |
+
self.argparse_schema = self.__merge_dict(
|
255 |
+
self.general_schema,
|
256 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
257 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
258 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
259 |
+
)
|
260 |
+
|
261 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
262 |
+
|
263 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
264 |
+
try:
|
265 |
+
return self.user_config_validator(user_config)
|
266 |
+
except MultipleInvalid:
|
267 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
268 |
+
print("Invalid user config / ユーザ設定の形式が正しくないようです")
|
269 |
+
raise
|
270 |
+
|
271 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
272 |
+
# However this will help us to detect program bug
|
273 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
274 |
+
try:
|
275 |
+
return self.argparse_config_validator(argparse_namespace)
|
276 |
+
except MultipleInvalid:
|
277 |
+
# XXX: this should be a bug
|
278 |
+
print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
279 |
+
raise
|
280 |
+
|
281 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
282 |
+
@staticmethod
|
283 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
284 |
+
merged = {}
|
285 |
+
for schema in dict_list:
|
286 |
+
# merged |= schema
|
287 |
+
for k, v in schema.items():
|
288 |
+
merged[k] = v
|
289 |
+
return merged
|
290 |
+
|
291 |
+
|
292 |
+
class BlueprintGenerator:
|
293 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
|
294 |
+
}
|
295 |
+
|
296 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
297 |
+
self.sanitizer = sanitizer
|
298 |
+
|
299 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
300 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
301 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
302 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
303 |
+
|
304 |
+
# convert argparse namespace to dict like config
|
305 |
+
# NOTE: it is ok to have extra entries in dict
|
306 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
307 |
+
argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
|
308 |
+
|
309 |
+
general_config = sanitized_user_config.get("general", {})
|
310 |
+
|
311 |
+
dataset_blueprints = []
|
312 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
313 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
314 |
+
subsets = dataset_config.get("subsets", [])
|
315 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
316 |
+
if is_dreambooth:
|
317 |
+
subset_params_klass = DreamBoothSubsetParams
|
318 |
+
dataset_params_klass = DreamBoothDatasetParams
|
319 |
+
else:
|
320 |
+
subset_params_klass = FineTuningSubsetParams
|
321 |
+
dataset_params_klass = FineTuningDatasetParams
|
322 |
+
|
323 |
+
subset_blueprints = []
|
324 |
+
for subset_config in subsets:
|
325 |
+
params = self.generate_params_by_fallbacks(subset_params_klass,
|
326 |
+
[subset_config, dataset_config, general_config, argparse_config, runtime_params])
|
327 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
328 |
+
|
329 |
+
params = self.generate_params_by_fallbacks(dataset_params_klass,
|
330 |
+
[dataset_config, general_config, argparse_config, runtime_params])
|
331 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
|
332 |
+
|
333 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
334 |
+
|
335 |
+
return Blueprint(dataset_group_blueprint)
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
339 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
340 |
+
search_value = BlueprintGenerator.search_value
|
341 |
+
default_params = asdict(param_klass())
|
342 |
+
param_names = default_params.keys()
|
343 |
+
|
344 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
345 |
+
|
346 |
+
return param_klass(**params)
|
347 |
+
|
348 |
+
@staticmethod
|
349 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
|
350 |
+
for cand in fallbacks:
|
351 |
+
value = cand.get(key)
|
352 |
+
if value is not None:
|
353 |
+
return value
|
354 |
+
|
355 |
+
return default_value
|
356 |
+
|
357 |
+
|
358 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
359 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
|
360 |
+
|
361 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
362 |
+
if dataset_blueprint.is_dreambooth:
|
363 |
+
subset_klass = DreamBoothSubset
|
364 |
+
dataset_klass = DreamBoothDataset
|
365 |
+
else:
|
366 |
+
subset_klass = FineTuningSubset
|
367 |
+
dataset_klass = FineTuningDataset
|
368 |
+
|
369 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
370 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
371 |
+
datasets.append(dataset)
|
372 |
+
|
373 |
+
# print info
|
374 |
+
info = ""
|
375 |
+
for i, dataset in enumerate(datasets):
|
376 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
377 |
+
info += dedent(f"""\
|
378 |
+
[Dataset {i}]
|
379 |
+
batch_size: {dataset.batch_size}
|
380 |
+
resolution: {(dataset.width, dataset.height)}
|
381 |
+
enable_bucket: {dataset.enable_bucket}
|
382 |
+
""")
|
383 |
+
|
384 |
+
if dataset.enable_bucket:
|
385 |
+
info += indent(dedent(f"""\
|
386 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
387 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
388 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
389 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
390 |
+
\n"""), " ")
|
391 |
+
else:
|
392 |
+
info += "\n"
|
393 |
+
|
394 |
+
for j, subset in enumerate(dataset.subsets):
|
395 |
+
info += indent(dedent(f"""\
|
396 |
+
[Subset {j} of Dataset {i}]
|
397 |
+
image_dir: "{subset.image_dir}"
|
398 |
+
image_count: {subset.img_count}
|
399 |
+
num_repeats: {subset.num_repeats}
|
400 |
+
shuffle_caption: {subset.shuffle_caption}
|
401 |
+
keep_tokens: {subset.keep_tokens}
|
402 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
403 |
+
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
404 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
405 |
+
color_aug: {subset.color_aug}
|
406 |
+
flip_aug: {subset.flip_aug}
|
407 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
408 |
+
random_crop: {subset.random_crop}
|
409 |
+
"""), " ")
|
410 |
+
|
411 |
+
if is_dreambooth:
|
412 |
+
info += indent(dedent(f"""\
|
413 |
+
is_reg: {subset.is_reg}
|
414 |
+
class_tokens: {subset.class_tokens}
|
415 |
+
caption_extension: {subset.caption_extension}
|
416 |
+
\n"""), " ")
|
417 |
+
else:
|
418 |
+
info += indent(dedent(f"""\
|
419 |
+
metadata_file: {subset.metadata_file}
|
420 |
+
\n"""), " ")
|
421 |
+
|
422 |
+
print(info)
|
423 |
+
|
424 |
+
# make buckets first because it determines the length of dataset
|
425 |
+
for i, dataset in enumerate(datasets):
|
426 |
+
print(f"[Dataset {i}]")
|
427 |
+
dataset.make_buckets()
|
428 |
+
|
429 |
+
return DatasetGroup(datasets)
|
430 |
+
|
431 |
+
|
432 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
433 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
434 |
+
tokens = name.split('_')
|
435 |
+
try:
|
436 |
+
n_repeats = int(tokens[0])
|
437 |
+
except ValueError as e:
|
438 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
439 |
+
return 0, ""
|
440 |
+
caption_by_folder = '_'.join(tokens[1:])
|
441 |
+
return n_repeats, caption_by_folder
|
442 |
+
|
443 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
444 |
+
if base_dir is None:
|
445 |
+
return []
|
446 |
+
|
447 |
+
base_dir: Path = Path(base_dir)
|
448 |
+
if not base_dir.is_dir():
|
449 |
+
return []
|
450 |
+
|
451 |
+
subsets_config = []
|
452 |
+
for subdir in base_dir.iterdir():
|
453 |
+
if not subdir.is_dir():
|
454 |
+
continue
|
455 |
+
|
456 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
457 |
+
if num_repeats < 1:
|
458 |
+
continue
|
459 |
+
|
460 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
461 |
+
subsets_config.append(subset_config)
|
462 |
+
|
463 |
+
return subsets_config
|
464 |
+
|
465 |
+
subsets_config = []
|
466 |
+
subsets_config += generate(train_data_dir, False)
|
467 |
+
subsets_config += generate(reg_data_dir, True)
|
468 |
+
|
469 |
+
return subsets_config
|
470 |
+
|
471 |
+
|
472 |
+
def load_user_config(file: str) -> dict:
|
473 |
+
file: Path = Path(file)
|
474 |
+
if not file.is_file():
|
475 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
476 |
+
|
477 |
+
if file.name.lower().endswith('.json'):
|
478 |
+
try:
|
479 |
+
config = json.load(file)
|
480 |
+
except Exception:
|
481 |
+
print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
482 |
+
raise
|
483 |
+
elif file.name.lower().endswith('.toml'):
|
484 |
+
try:
|
485 |
+
config = toml.load(file)
|
486 |
+
except Exception:
|
487 |
+
print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
488 |
+
raise
|
489 |
+
else:
|
490 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
491 |
+
|
492 |
+
return config
|
493 |
+
|
494 |
+
|
495 |
+
# for config test
|
496 |
+
if __name__ == "__main__":
|
497 |
+
parser = argparse.ArgumentParser()
|
498 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
499 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
500 |
+
parser.add_argument("--support_dropout", action="store_true")
|
501 |
+
parser.add_argument("dataset_config")
|
502 |
+
config_args, remain = parser.parse_known_args()
|
503 |
+
|
504 |
+
parser = argparse.ArgumentParser()
|
505 |
+
train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
|
506 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
507 |
+
argparse_namespace = parser.parse_args(remain)
|
508 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
509 |
+
|
510 |
+
print("[argparse_namespace]")
|
511 |
+
print(vars(argparse_namespace))
|
512 |
+
|
513 |
+
user_config = load_user_config(config_args.dataset_config)
|
514 |
+
|
515 |
+
print("\n[user_config]")
|
516 |
+
print(user_config)
|
517 |
+
|
518 |
+
sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
|
519 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
520 |
+
|
521 |
+
print("\n[sanitized_user_config]")
|
522 |
+
print(sanitized_user_config)
|
523 |
+
|
524 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
525 |
+
|
526 |
+
print("\n[blueprint]")
|
527 |
+
print(blueprint)
|
library/train_util.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
|
|
4 |
import json
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
-
from torch.autograd.function import Function
|
10 |
import glob
|
11 |
import math
|
12 |
import os
|
@@ -17,10 +26,16 @@ from io import BytesIO
|
|
17 |
|
18 |
from tqdm import tqdm
|
19 |
import torch
|
|
|
20 |
from torchvision import transforms
|
21 |
from transformers import CLIPTokenizer
|
|
|
22 |
import diffusers
|
23 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
24 |
import albumentations as albu
|
25 |
import numpy as np
|
26 |
from PIL import Image
|
@@ -195,23 +210,93 @@ class BucketBatchIndex(NamedTuple):
|
|
195 |
batch_index: int
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
class BaseDataset(torch.utils.data.Dataset):
|
199 |
-
def __init__(self, tokenizer, max_token_length
|
200 |
super().__init__()
|
201 |
-
self.tokenizer
|
202 |
self.max_token_length = max_token_length
|
203 |
-
self.shuffle_caption = shuffle_caption
|
204 |
-
self.shuffle_keep_tokens = shuffle_keep_tokens
|
205 |
# width/height is used when enable_bucket==False
|
206 |
self.width, self.height = (None, None) if resolution is None else resolution
|
207 |
-
self.face_crop_aug_range = face_crop_aug_range
|
208 |
-
self.flip_aug = flip_aug
|
209 |
-
self.color_aug = color_aug
|
210 |
self.debug_dataset = debug_dataset
|
211 |
-
|
|
|
|
|
212 |
self.token_padding_disabled = False
|
213 |
-
self.dataset_dirs_info = {}
|
214 |
-
self.reg_dataset_dirs_info = {}
|
215 |
self.tag_frequency = {}
|
216 |
|
217 |
self.enable_bucket = False
|
@@ -225,49 +310,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
225 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
226 |
|
227 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
228 |
-
self.dropout_rate: float = 0
|
229 |
-
self.dropout_every_n_epochs: int = None
|
230 |
-
self.tag_dropout_rate: float = 0
|
231 |
|
232 |
# augmentation
|
233 |
-
|
234 |
-
if color_aug:
|
235 |
-
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
236 |
-
self.aug = albu.Compose([
|
237 |
-
albu.OneOf([
|
238 |
-
albu.HueSaturationValue(8, 0, 0, p=.5),
|
239 |
-
albu.RandomGamma((95, 105), p=.5),
|
240 |
-
], p=.33),
|
241 |
-
albu.HorizontalFlip(p=flip_p)
|
242 |
-
], p=1.)
|
243 |
-
elif flip_aug:
|
244 |
-
self.aug = albu.Compose([
|
245 |
-
albu.HorizontalFlip(p=flip_p)
|
246 |
-
], p=1.)
|
247 |
-
else:
|
248 |
-
self.aug = None
|
249 |
|
250 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
251 |
|
252 |
self.image_data: Dict[str, ImageInfo] = {}
|
|
|
253 |
|
254 |
self.replacements = {}
|
255 |
|
256 |
def set_current_epoch(self, epoch):
|
257 |
self.current_epoch = epoch
|
258 |
-
|
259 |
-
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
260 |
-
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
261 |
-
self.dropout_rate = dropout_rate
|
262 |
-
self.dropout_every_n_epochs = dropout_every_n_epochs
|
263 |
-
self.tag_dropout_rate = tag_dropout_rate
|
264 |
|
265 |
def set_tag_frequency(self, dir_name, captions):
|
266 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
267 |
self.tag_frequency[dir_name] = frequency_for_dir
|
268 |
for caption in captions:
|
269 |
for tag in caption.split(","):
|
270 |
-
|
|
|
271 |
tag = tag.lower()
|
272 |
frequency = frequency_for_dir.get(tag, 0)
|
273 |
frequency_for_dir[tag] = frequency + 1
|
@@ -278,42 +342,36 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
278 |
def add_replacement(self, str_from, str_to):
|
279 |
self.replacements[str_from] = str_to
|
280 |
|
281 |
-
def process_caption(self, caption):
|
282 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
283 |
-
is_drop_out =
|
284 |
-
is_drop_out = is_drop_out or
|
285 |
|
286 |
if is_drop_out:
|
287 |
caption = ""
|
288 |
else:
|
289 |
-
if
|
290 |
def dropout_tags(tokens):
|
291 |
-
if
|
292 |
return tokens
|
293 |
l = []
|
294 |
for token in tokens:
|
295 |
-
if random.random() >=
|
296 |
l.append(token)
|
297 |
return l
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
tokens = dropout_tags(tokens)
|
305 |
-
else:
|
306 |
-
if len(tokens) > self.shuffle_keep_tokens:
|
307 |
-
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
308 |
-
tokens = tokens[self.shuffle_keep_tokens:]
|
309 |
|
310 |
-
|
311 |
-
|
312 |
|
313 |
-
|
314 |
|
315 |
-
|
316 |
-
caption = ", ".join(tokens)
|
317 |
|
318 |
# textual inversion対応
|
319 |
for str_from, str_to in self.replacements.items():
|
@@ -367,8 +425,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
367 |
input_ids = torch.stack(iids_list) # 3,77
|
368 |
return input_ids
|
369 |
|
370 |
-
def register_image(self, info: ImageInfo):
|
371 |
self.image_data[info.image_key] = info
|
|
|
372 |
|
373 |
def make_buckets(self):
|
374 |
'''
|
@@ -467,7 +526,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
467 |
img = np.array(image, np.uint8)
|
468 |
return img
|
469 |
|
470 |
-
def trim_and_resize_if_required(self, image, reso, resized_size):
|
471 |
image_height, image_width = image.shape[0:2]
|
472 |
|
473 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
@@ -477,22 +536,27 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
477 |
image_height, image_width = image.shape[0:2]
|
478 |
if image_width > reso[0]:
|
479 |
trim_size = image_width - reso[0]
|
480 |
-
p = trim_size // 2 if not
|
481 |
# print("w", trim_size, p)
|
482 |
image = image[:, p:p + reso[0]]
|
483 |
if image_height > reso[1]:
|
484 |
trim_size = image_height - reso[1]
|
485 |
-
p = trim_size // 2 if not
|
486 |
# print("h", trim_size, p)
|
487 |
image = image[p:p + reso[1]]
|
488 |
|
489 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
490 |
return image
|
491 |
|
|
|
|
|
|
|
492 |
def cache_latents(self, vae):
|
493 |
# TODO ここを高速化したい
|
494 |
print("caching latents.")
|
495 |
for info in tqdm(self.image_data.values()):
|
|
|
|
|
496 |
if info.latents_npz is not None:
|
497 |
info.latents = self.load_latents_from_npz(info, False)
|
498 |
info.latents = torch.FloatTensor(info.latents)
|
@@ -502,13 +566,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
502 |
continue
|
503 |
|
504 |
image = self.load_image(info.absolute_path)
|
505 |
-
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
506 |
|
507 |
img_tensor = self.image_transforms(image)
|
508 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
509 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
510 |
|
511 |
-
if
|
512 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
513 |
img_tensor = self.image_transforms(image)
|
514 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
@@ -518,11 +582,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
518 |
image = Image.open(image_path)
|
519 |
return image.size
|
520 |
|
521 |
-
def load_image_with_face_info(self, image_path: str):
|
522 |
img = self.load_image(image_path)
|
523 |
|
524 |
face_cx = face_cy = face_w = face_h = 0
|
525 |
-
if
|
526 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
527 |
if len(tokens) >= 5:
|
528 |
face_cx = int(tokens[-4])
|
@@ -533,7 +597,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
533 |
return img, face_cx, face_cy, face_w, face_h
|
534 |
|
535 |
# いい感じに切り出す
|
536 |
-
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
537 |
height, width = image.shape[0:2]
|
538 |
if height == self.height and width == self.width:
|
539 |
return image
|
@@ -541,8 +605,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
541 |
# 画像サイズはsizeより大きいのでリサイズする
|
542 |
face_size = max(face_w, face_h)
|
543 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
544 |
-
min_scale = min(1.0, max(min_scale, self.size / (face_size *
|
545 |
-
max_scale = min(1.0, max(min_scale, self.size / (face_size *
|
546 |
if min_scale >= max_scale: # range指定がmin==max
|
547 |
scale = min_scale
|
548 |
else:
|
@@ -560,13 +624,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
560 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
561 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
562 |
|
563 |
-
if
|
564 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
565 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
566 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
567 |
else:
|
568 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
569 |
-
if
|
570 |
if face_size > self.size // 10 and face_size >= 40:
|
571 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
572 |
|
@@ -589,9 +653,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
589 |
return self._length
|
590 |
|
591 |
def __getitem__(self, index):
|
592 |
-
if index == 0:
|
593 |
-
self.shuffle_buckets()
|
594 |
-
|
595 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
596 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
597 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
@@ -604,28 +665,29 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
604 |
|
605 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
606 |
image_info = self.image_data[image_key]
|
|
|
607 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
608 |
|
609 |
# image/latentsを処理する
|
610 |
if image_info.latents is not None:
|
611 |
-
latents = image_info.latents if not
|
612 |
image = None
|
613 |
elif image_info.latents_npz is not None:
|
614 |
-
latents = self.load_latents_from_npz(image_info,
|
615 |
latents = torch.FloatTensor(latents)
|
616 |
image = None
|
617 |
else:
|
618 |
# 画像を読み込み、必要ならcropする
|
619 |
-
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
620 |
im_h, im_w = img.shape[0:2]
|
621 |
|
622 |
if self.enable_bucket:
|
623 |
-
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
624 |
else:
|
625 |
if face_cx > 0: # 顔位置情報あり
|
626 |
-
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
627 |
elif im_h > self.height or im_w > self.width:
|
628 |
-
assert
|
629 |
if im_h > self.height:
|
630 |
p = random.randint(0, im_h - self.height)
|
631 |
img = img[p:p + self.height]
|
@@ -637,8 +699,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
637 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
638 |
|
639 |
# augmentation
|
640 |
-
|
641 |
-
|
|
|
642 |
|
643 |
latents = None
|
644 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
@@ -646,7 +709,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
646 |
images.append(image)
|
647 |
latents_list.append(latents)
|
648 |
|
649 |
-
caption = self.process_caption(image_info.caption)
|
650 |
captions.append(caption)
|
651 |
if not self.token_padding_disabled: # this option might be omitted in future
|
652 |
input_ids_list.append(self.get_input_ids(caption))
|
@@ -677,9 +740,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
677 |
|
678 |
|
679 |
class DreamBoothDataset(BaseDataset):
|
680 |
-
def __init__(self,
|
681 |
-
super().__init__(tokenizer, max_token_length,
|
682 |
-
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
683 |
|
684 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
685 |
|
@@ -702,7 +764,7 @@ class DreamBoothDataset(BaseDataset):
|
|
702 |
self.bucket_reso_steps = None # この情報は使われない
|
703 |
self.bucket_no_upscale = False
|
704 |
|
705 |
-
def read_caption(img_path):
|
706 |
# captionの候補ファイル名を作る
|
707 |
base_name = os.path.splitext(img_path)[0]
|
708 |
base_name_face_det = base_name
|
@@ -725,153 +787,171 @@ class DreamBoothDataset(BaseDataset):
|
|
725 |
break
|
726 |
return caption
|
727 |
|
728 |
-
def load_dreambooth_dir(
|
729 |
-
if not os.path.isdir(
|
730 |
-
|
731 |
-
return
|
732 |
-
|
733 |
-
tokens = os.path.basename(dir).split('_')
|
734 |
-
try:
|
735 |
-
n_repeats = int(tokens[0])
|
736 |
-
except ValueError as e:
|
737 |
-
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
738 |
-
return 0, [], []
|
739 |
|
740 |
-
|
741 |
-
|
742 |
-
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
743 |
|
744 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
745 |
captions = []
|
746 |
for img_path in img_paths:
|
747 |
-
cap_for_img = read_caption(img_path)
|
748 |
-
|
749 |
-
|
750 |
-
|
|
|
|
|
|
|
|
|
751 |
|
752 |
-
return
|
753 |
|
754 |
-
print("prepare
|
755 |
-
train_dirs = os.listdir(train_data_dir)
|
756 |
num_train_images = 0
|
757 |
-
|
758 |
-
|
759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
|
761 |
for img_path, caption in zip(img_paths, captions):
|
762 |
-
info = ImageInfo(img_path,
|
763 |
-
|
|
|
|
|
|
|
764 |
|
765 |
-
|
|
|
766 |
|
767 |
print(f"{num_train_images} train images with repeating.")
|
768 |
self.num_train_images = num_train_images
|
769 |
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
print("prepare reg images.")
|
774 |
-
reg_infos: List[ImageInfo] = []
|
775 |
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
780 |
|
781 |
-
|
782 |
-
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
783 |
-
reg_infos.append(info)
|
784 |
|
785 |
-
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
786 |
|
787 |
-
|
788 |
-
|
789 |
-
|
|
|
|
|
|
|
|
|
|
|
790 |
|
791 |
-
|
792 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
else:
|
794 |
-
|
795 |
-
n = 0
|
796 |
-
first_loop = True
|
797 |
-
while n < num_train_images:
|
798 |
-
for info in reg_infos:
|
799 |
-
if first_loop:
|
800 |
-
self.register_image(info)
|
801 |
-
n += info.num_repeats
|
802 |
-
else:
|
803 |
-
info.num_repeats += 1
|
804 |
-
n += 1
|
805 |
-
if n >= num_train_images:
|
806 |
-
break
|
807 |
-
first_loop = False
|
808 |
|
809 |
-
|
|
|
|
|
810 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
with open(json_file_name, "rt", encoding='utf-8') as f:
|
821 |
-
metadata = json.load(f)
|
822 |
-
else:
|
823 |
-
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
824 |
|
825 |
-
|
826 |
-
|
827 |
-
self.batch_size = batch_size
|
828 |
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
else:
|
835 |
-
# わりといい加減だがいい方法が思いつかん
|
836 |
-
abs_path = glob_images(train_data_dir, image_key)
|
837 |
-
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
838 |
-
abs_path = abs_path[0]
|
839 |
-
|
840 |
-
caption = img_md.get('caption')
|
841 |
-
tags = img_md.get('tags')
|
842 |
-
if caption is None:
|
843 |
-
caption = tags
|
844 |
-
elif tags is not None and len(tags) > 0:
|
845 |
-
caption = caption + ', ' + tags
|
846 |
-
tags_list.append(tags)
|
847 |
-
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
848 |
-
|
849 |
-
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
850 |
-
image_info.image_size = img_md.get('train_resolution')
|
851 |
-
|
852 |
-
if not self.color_aug and not self.random_crop:
|
853 |
-
# if npz exists, use them
|
854 |
-
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
855 |
-
|
856 |
-
self.register_image(image_info)
|
857 |
-
self.num_train_images = len(metadata) * dataset_repeats
|
858 |
-
self.num_reg_images = 0
|
859 |
|
860 |
-
|
861 |
-
|
862 |
-
|
|
|
|
|
|
|
863 |
|
864 |
# check existence of all npz files
|
865 |
-
use_npz_latents = not
|
866 |
if use_npz_latents:
|
|
|
867 |
npz_any = False
|
868 |
npz_all = True
|
|
|
869 |
for image_info in self.image_data.values():
|
|
|
|
|
870 |
has_npz = image_info.latents_npz is not None
|
871 |
npz_any = npz_any or has_npz
|
872 |
|
873 |
-
if
|
874 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
|
|
875 |
npz_all = npz_all and has_npz
|
876 |
|
877 |
if npz_any and not npz_all:
|
@@ -883,7 +963,7 @@ class FineTuningDataset(BaseDataset):
|
|
883 |
elif not npz_all:
|
884 |
use_npz_latents = False
|
885 |
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
886 |
-
if
|
887 |
print("maybe no flipped files / ��転されたnpzファイルがないのかもしれません")
|
888 |
# else:
|
889 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
@@ -929,7 +1009,7 @@ class FineTuningDataset(BaseDataset):
|
|
929 |
for image_info in self.image_data.values():
|
930 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
931 |
|
932 |
-
def image_key_to_npz_file(self, image_key):
|
933 |
base_name = os.path.splitext(image_key)[0]
|
934 |
npz_file_norm = base_name + '.npz'
|
935 |
|
@@ -941,8 +1021,8 @@ class FineTuningDataset(BaseDataset):
|
|
941 |
return npz_file_norm, npz_file_flip
|
942 |
|
943 |
# image_key is relative path
|
944 |
-
npz_file_norm = os.path.join(
|
945 |
-
npz_file_flip = os.path.join(
|
946 |
|
947 |
if not os.path.exists(npz_file_norm):
|
948 |
npz_file_norm = None
|
@@ -953,13 +1033,60 @@ class FineTuningDataset(BaseDataset):
|
|
953 |
return npz_file_norm, npz_file_flip
|
954 |
|
955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
def debug_dataset(train_dataset, show_input_ids=False):
|
957 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
958 |
print("Escape for exit. / Escキーで中断、終了します")
|
959 |
|
960 |
train_dataset.set_current_epoch(1)
|
961 |
k = 0
|
962 |
-
|
|
|
|
|
|
|
963 |
if example['latents'] is not None:
|
964 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
965 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
@@ -1364,6 +1491,35 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|
1364 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1365 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1366 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
|
1369 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
@@ -1387,10 +1543,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1387 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1388 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1389 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1390 |
-
parser.add_argument("--use_8bit_adam", action="store_true",
|
1391 |
-
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1392 |
-
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1393 |
-
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1394 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1395 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1396 |
parser.add_argument("--xformers", action="store_true",
|
@@ -1398,7 +1550,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1398 |
parser.add_argument("--vae", type=str, default=None,
|
1399 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1400 |
|
1401 |
-
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1402 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1403 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1404 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
@@ -1419,15 +1570,23 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|
1419 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1420 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1421 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1422 |
-
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1423 |
-
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1424 |
-
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1425 |
-
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1426 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1427 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1428 |
parser.add_argument("--lowram", action="store_true",
|
1429 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1431 |
if support_dreambooth:
|
1432 |
# DreamBooth training
|
1433 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
@@ -1449,8 +1608,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1449 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1450 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1451 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1452 |
-
parser.add_argument("--keep_tokens", type=int, default=
|
1453 |
-
help="keep heading N tokens when shuffling caption tokens / caption
|
1454 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1455 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1456 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
@@ -1475,11 +1634,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|
1475 |
if support_caption_dropout:
|
1476 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1477 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1478 |
-
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1479 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1480 |
-
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=
|
1481 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1482 |
-
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1483 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1484 |
|
1485 |
if support_dreambooth:
|
@@ -1504,16 +1663,249 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|
1504 |
# region utils
|
1505 |
|
1506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1507 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1508 |
# backward compatibility
|
1509 |
if args.caption_extention is not None:
|
1510 |
args.caption_extension = args.caption_extention
|
1511 |
args.caption_extention = None
|
1512 |
|
1513 |
-
if args.cache_latents:
|
1514 |
-
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1515 |
-
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1516 |
-
|
1517 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1518 |
if args.resolution is not None:
|
1519 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
@@ -1536,12 +1928,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|
1536 |
|
1537 |
def load_tokenizer(args: argparse.Namespace):
|
1538 |
print("prepare tokenizer")
|
1539 |
-
if args.v2
|
1540 |
-
|
1541 |
-
|
1542 |
-
|
1543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1544 |
print(f"update token length: {args.max_token_length}")
|
|
|
|
|
|
|
|
|
|
|
1545 |
return tokenizer
|
1546 |
|
1547 |
|
@@ -1592,13 +2000,19 @@ def prepare_dtype(args: argparse.Namespace):
|
|
1592 |
|
1593 |
|
1594 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1595 |
-
|
|
|
|
|
1596 |
if load_stable_diffusion_format:
|
1597 |
print("load StableDiffusion checkpoint")
|
1598 |
-
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2,
|
1599 |
else:
|
1600 |
print("load Diffusers pretrained models")
|
1601 |
-
|
|
|
|
|
|
|
|
|
1602 |
text_encoder = pipe.text_encoder
|
1603 |
vae = pipe.vae
|
1604 |
unet = pipe.unet
|
@@ -1767,6 +2181,185 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
1767 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1768 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1769 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1770 |
# endregion
|
1771 |
|
1772 |
# region 前処理用
|
|
|
1 |
# common functions for training
|
2 |
|
3 |
import argparse
|
4 |
+
import importlib
|
5 |
import json
|
6 |
+
import re
|
7 |
import shutil
|
8 |
import time
|
9 |
+
from typing import (
|
10 |
+
Dict,
|
11 |
+
List,
|
12 |
+
NamedTuple,
|
13 |
+
Optional,
|
14 |
+
Sequence,
|
15 |
+
Tuple,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
from accelerate import Accelerator
|
|
|
19 |
import glob
|
20 |
import math
|
21 |
import os
|
|
|
26 |
|
27 |
from tqdm import tqdm
|
28 |
import torch
|
29 |
+
from torch.optim import Optimizer
|
30 |
from torchvision import transforms
|
31 |
from transformers import CLIPTokenizer
|
32 |
+
import transformers
|
33 |
import diffusers
|
34 |
+
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
35 |
+
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
36 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
37 |
+
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
38 |
+
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
39 |
import albumentations as albu
|
40 |
import numpy as np
|
41 |
from PIL import Image
|
|
|
210 |
batch_index: int
|
211 |
|
212 |
|
213 |
+
class AugHelper:
|
214 |
+
def __init__(self):
|
215 |
+
# prepare all possible augmentators
|
216 |
+
color_aug_method = albu.OneOf([
|
217 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
218 |
+
albu.RandomGamma((95, 105), p=.5),
|
219 |
+
], p=.33)
|
220 |
+
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
221 |
+
|
222 |
+
# key: (use_color_aug, use_flip_aug)
|
223 |
+
self.augmentors = {
|
224 |
+
(True, True): albu.Compose([
|
225 |
+
color_aug_method,
|
226 |
+
flip_aug_method,
|
227 |
+
], p=1.),
|
228 |
+
(True, False): albu.Compose([
|
229 |
+
color_aug_method,
|
230 |
+
], p=1.),
|
231 |
+
(False, True): albu.Compose([
|
232 |
+
flip_aug_method,
|
233 |
+
], p=1.),
|
234 |
+
(False, False): None
|
235 |
+
}
|
236 |
+
|
237 |
+
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
238 |
+
return self.augmentors[(use_color_aug, use_flip_aug)]
|
239 |
+
|
240 |
+
|
241 |
+
class BaseSubset:
|
242 |
+
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
243 |
+
self.image_dir = image_dir
|
244 |
+
self.num_repeats = num_repeats
|
245 |
+
self.shuffle_caption = shuffle_caption
|
246 |
+
self.keep_tokens = keep_tokens
|
247 |
+
self.color_aug = color_aug
|
248 |
+
self.flip_aug = flip_aug
|
249 |
+
self.face_crop_aug_range = face_crop_aug_range
|
250 |
+
self.random_crop = random_crop
|
251 |
+
self.caption_dropout_rate = caption_dropout_rate
|
252 |
+
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
253 |
+
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
254 |
+
|
255 |
+
self.img_count = 0
|
256 |
+
|
257 |
+
|
258 |
+
class DreamBoothSubset(BaseSubset):
|
259 |
+
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
260 |
+
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
261 |
+
|
262 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
263 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
264 |
+
|
265 |
+
self.is_reg = is_reg
|
266 |
+
self.class_tokens = class_tokens
|
267 |
+
self.caption_extension = caption_extension
|
268 |
+
|
269 |
+
def __eq__(self, other) -> bool:
|
270 |
+
if not isinstance(other, DreamBoothSubset):
|
271 |
+
return NotImplemented
|
272 |
+
return self.image_dir == other.image_dir
|
273 |
+
|
274 |
+
class FineTuningSubset(BaseSubset):
|
275 |
+
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
276 |
+
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
277 |
+
|
278 |
+
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
279 |
+
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
280 |
+
|
281 |
+
self.metadata_file = metadata_file
|
282 |
+
|
283 |
+
def __eq__(self, other) -> bool:
|
284 |
+
if not isinstance(other, FineTuningSubset):
|
285 |
+
return NotImplemented
|
286 |
+
return self.metadata_file == other.metadata_file
|
287 |
+
|
288 |
class BaseDataset(torch.utils.data.Dataset):
|
289 |
+
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
290 |
super().__init__()
|
291 |
+
self.tokenizer = tokenizer
|
292 |
self.max_token_length = max_token_length
|
|
|
|
|
293 |
# width/height is used when enable_bucket==False
|
294 |
self.width, self.height = (None, None) if resolution is None else resolution
|
|
|
|
|
|
|
295 |
self.debug_dataset = debug_dataset
|
296 |
+
|
297 |
+
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
298 |
+
|
299 |
self.token_padding_disabled = False
|
|
|
|
|
300 |
self.tag_frequency = {}
|
301 |
|
302 |
self.enable_bucket = False
|
|
|
310 |
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
311 |
|
312 |
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
|
|
|
|
|
|
313 |
|
314 |
# augmentation
|
315 |
+
self.aug_helper = AugHelper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
318 |
|
319 |
self.image_data: Dict[str, ImageInfo] = {}
|
320 |
+
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
321 |
|
322 |
self.replacements = {}
|
323 |
|
324 |
def set_current_epoch(self, epoch):
|
325 |
self.current_epoch = epoch
|
326 |
+
self.shuffle_buckets()
|
|
|
|
|
|
|
|
|
|
|
327 |
|
328 |
def set_tag_frequency(self, dir_name, captions):
|
329 |
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
330 |
self.tag_frequency[dir_name] = frequency_for_dir
|
331 |
for caption in captions:
|
332 |
for tag in caption.split(","):
|
333 |
+
tag = tag.strip()
|
334 |
+
if tag:
|
335 |
tag = tag.lower()
|
336 |
frequency = frequency_for_dir.get(tag, 0)
|
337 |
frequency_for_dir[tag] = frequency + 1
|
|
|
342 |
def add_replacement(self, str_from, str_to):
|
343 |
self.replacements[str_from] = str_to
|
344 |
|
345 |
+
def process_caption(self, subset: BaseSubset, caption):
|
346 |
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
347 |
+
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
348 |
+
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
349 |
|
350 |
if is_drop_out:
|
351 |
caption = ""
|
352 |
else:
|
353 |
+
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
354 |
def dropout_tags(tokens):
|
355 |
+
if subset.caption_tag_dropout_rate <= 0:
|
356 |
return tokens
|
357 |
l = []
|
358 |
for token in tokens:
|
359 |
+
if random.random() >= subset.caption_tag_dropout_rate:
|
360 |
l.append(token)
|
361 |
return l
|
362 |
|
363 |
+
fixed_tokens = []
|
364 |
+
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
365 |
+
if subset.keep_tokens > 0:
|
366 |
+
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
367 |
+
flex_tokens = flex_tokens[subset.keep_tokens:]
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
+
if subset.shuffle_caption:
|
370 |
+
random.shuffle(flex_tokens)
|
371 |
|
372 |
+
flex_tokens = dropout_tags(flex_tokens)
|
373 |
|
374 |
+
caption = ", ".join(fixed_tokens + flex_tokens)
|
|
|
375 |
|
376 |
# textual inversion対応
|
377 |
for str_from, str_to in self.replacements.items():
|
|
|
425 |
input_ids = torch.stack(iids_list) # 3,77
|
426 |
return input_ids
|
427 |
|
428 |
+
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
429 |
self.image_data[info.image_key] = info
|
430 |
+
self.image_to_subset[info.image_key] = subset
|
431 |
|
432 |
def make_buckets(self):
|
433 |
'''
|
|
|
526 |
img = np.array(image, np.uint8)
|
527 |
return img
|
528 |
|
529 |
+
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
530 |
image_height, image_width = image.shape[0:2]
|
531 |
|
532 |
if image_width != resized_size[0] or image_height != resized_size[1]:
|
|
|
536 |
image_height, image_width = image.shape[0:2]
|
537 |
if image_width > reso[0]:
|
538 |
trim_size = image_width - reso[0]
|
539 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
540 |
# print("w", trim_size, p)
|
541 |
image = image[:, p:p + reso[0]]
|
542 |
if image_height > reso[1]:
|
543 |
trim_size = image_height - reso[1]
|
544 |
+
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
545 |
# print("h", trim_size, p)
|
546 |
image = image[p:p + reso[1]]
|
547 |
|
548 |
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
549 |
return image
|
550 |
|
551 |
+
def is_latent_cacheable(self):
|
552 |
+
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
553 |
+
|
554 |
def cache_latents(self, vae):
|
555 |
# TODO ここを高速化したい
|
556 |
print("caching latents.")
|
557 |
for info in tqdm(self.image_data.values()):
|
558 |
+
subset = self.image_to_subset[info.image_key]
|
559 |
+
|
560 |
if info.latents_npz is not None:
|
561 |
info.latents = self.load_latents_from_npz(info, False)
|
562 |
info.latents = torch.FloatTensor(info.latents)
|
|
|
566 |
continue
|
567 |
|
568 |
image = self.load_image(info.absolute_path)
|
569 |
+
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
570 |
|
571 |
img_tensor = self.image_transforms(image)
|
572 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
573 |
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
574 |
|
575 |
+
if subset.flip_aug:
|
576 |
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
577 |
img_tensor = self.image_transforms(image)
|
578 |
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
|
|
582 |
image = Image.open(image_path)
|
583 |
return image.size
|
584 |
|
585 |
+
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
586 |
img = self.load_image(image_path)
|
587 |
|
588 |
face_cx = face_cy = face_w = face_h = 0
|
589 |
+
if subset.face_crop_aug_range is not None:
|
590 |
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
591 |
if len(tokens) >= 5:
|
592 |
face_cx = int(tokens[-4])
|
|
|
597 |
return img, face_cx, face_cy, face_w, face_h
|
598 |
|
599 |
# いい感じに切り出す
|
600 |
+
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
|
601 |
height, width = image.shape[0:2]
|
602 |
if height == self.height and width == self.width:
|
603 |
return image
|
|
|
605 |
# 画像サイズはsizeより大きいのでリサイズする
|
606 |
face_size = max(face_w, face_h)
|
607 |
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
608 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
609 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
610 |
if min_scale >= max_scale: # range指定がmin==max
|
611 |
scale = min_scale
|
612 |
else:
|
|
|
624 |
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
625 |
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
626 |
|
627 |
+
if subset.random_crop:
|
628 |
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
629 |
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
630 |
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
631 |
else:
|
632 |
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
633 |
+
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
634 |
if face_size > self.size // 10 and face_size >= 40:
|
635 |
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
636 |
|
|
|
653 |
return self._length
|
654 |
|
655 |
def __getitem__(self, index):
|
|
|
|
|
|
|
656 |
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
657 |
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
658 |
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
|
|
665 |
|
666 |
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
667 |
image_info = self.image_data[image_key]
|
668 |
+
subset = self.image_to_subset[image_key]
|
669 |
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
670 |
|
671 |
# image/latentsを処理する
|
672 |
if image_info.latents is not None:
|
673 |
+
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
|
674 |
image = None
|
675 |
elif image_info.latents_npz is not None:
|
676 |
+
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
|
677 |
latents = torch.FloatTensor(latents)
|
678 |
image = None
|
679 |
else:
|
680 |
# 画像を読み込み、必要ならcropする
|
681 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
682 |
im_h, im_w = img.shape[0:2]
|
683 |
|
684 |
if self.enable_bucket:
|
685 |
+
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
686 |
else:
|
687 |
if face_cx > 0: # 顔位置情報あり
|
688 |
+
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
689 |
elif im_h > self.height or im_w > self.width:
|
690 |
+
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
691 |
if im_h > self.height:
|
692 |
p = random.randint(0, im_h - self.height)
|
693 |
img = img[p:p + self.height]
|
|
|
699 |
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
700 |
|
701 |
# augmentation
|
702 |
+
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
703 |
+
if aug is not None:
|
704 |
+
img = aug(image=img)['image']
|
705 |
|
706 |
latents = None
|
707 |
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
|
709 |
images.append(image)
|
710 |
latents_list.append(latents)
|
711 |
|
712 |
+
caption = self.process_caption(subset, image_info.caption)
|
713 |
captions.append(caption)
|
714 |
if not self.token_padding_disabled: # this option might be omitted in future
|
715 |
input_ids_list.append(self.get_input_ids(caption))
|
|
|
740 |
|
741 |
|
742 |
class DreamBoothDataset(BaseDataset):
|
743 |
+
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
|
744 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
|
|
745 |
|
746 |
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
747 |
|
|
|
764 |
self.bucket_reso_steps = None # この情報は使われない
|
765 |
self.bucket_no_upscale = False
|
766 |
|
767 |
+
def read_caption(img_path, caption_extension):
|
768 |
# captionの候補ファイル名を作る
|
769 |
base_name = os.path.splitext(img_path)[0]
|
770 |
base_name_face_det = base_name
|
|
|
787 |
break
|
788 |
return caption
|
789 |
|
790 |
+
def load_dreambooth_dir(subset: DreamBoothSubset):
|
791 |
+
if not os.path.isdir(subset.image_dir):
|
792 |
+
print(f"not directory: {subset.image_dir}")
|
793 |
+
return [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
794 |
|
795 |
+
img_paths = glob_images(subset.image_dir, "*")
|
796 |
+
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
|
|
797 |
|
798 |
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
799 |
captions = []
|
800 |
for img_path in img_paths:
|
801 |
+
cap_for_img = read_caption(img_path, subset.caption_extension)
|
802 |
+
if cap_for_img is None and subset.class_tokens is None:
|
803 |
+
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
804 |
+
captions.append("")
|
805 |
+
else:
|
806 |
+
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
807 |
+
|
808 |
+
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
809 |
|
810 |
+
return img_paths, captions
|
811 |
|
812 |
+
print("prepare images.")
|
|
|
813 |
num_train_images = 0
|
814 |
+
num_reg_images = 0
|
815 |
+
reg_infos: List[ImageInfo] = []
|
816 |
+
for subset in subsets:
|
817 |
+
if subset.num_repeats < 1:
|
818 |
+
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
819 |
+
continue
|
820 |
+
|
821 |
+
if subset in self.subsets:
|
822 |
+
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
823 |
+
continue
|
824 |
+
|
825 |
+
img_paths, captions = load_dreambooth_dir(subset)
|
826 |
+
if len(img_paths) < 1:
|
827 |
+
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
828 |
+
continue
|
829 |
+
|
830 |
+
if subset.is_reg:
|
831 |
+
num_reg_images += subset.num_repeats * len(img_paths)
|
832 |
+
else:
|
833 |
+
num_train_images += subset.num_repeats * len(img_paths)
|
834 |
|
835 |
for img_path, caption in zip(img_paths, captions):
|
836 |
+
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
837 |
+
if subset.is_reg:
|
838 |
+
reg_infos.append(info)
|
839 |
+
else:
|
840 |
+
self.register_image(info, subset)
|
841 |
|
842 |
+
subset.img_count = len(img_paths)
|
843 |
+
self.subsets.append(subset)
|
844 |
|
845 |
print(f"{num_train_images} train images with repeating.")
|
846 |
self.num_train_images = num_train_images
|
847 |
|
848 |
+
print(f"{num_reg_images} reg images.")
|
849 |
+
if num_train_images < num_reg_images:
|
850 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
|
|
|
|
851 |
|
852 |
+
if num_reg_images == 0:
|
853 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
854 |
+
else:
|
855 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
856 |
+
n = 0
|
857 |
+
first_loop = True
|
858 |
+
while n < num_train_images:
|
859 |
+
for info in reg_infos:
|
860 |
+
if first_loop:
|
861 |
+
self.register_image(info, subset)
|
862 |
+
n += info.num_repeats
|
863 |
+
else:
|
864 |
+
info.num_repeats += 1
|
865 |
+
n += 1
|
866 |
+
if n >= num_train_images:
|
867 |
+
break
|
868 |
+
first_loop = False
|
869 |
|
870 |
+
self.num_reg_images = num_reg_images
|
|
|
|
|
871 |
|
|
|
872 |
|
873 |
+
class FineTuningDataset(BaseDataset):
|
874 |
+
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
875 |
+
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
876 |
+
|
877 |
+
self.batch_size = batch_size
|
878 |
+
|
879 |
+
self.num_train_images = 0
|
880 |
+
self.num_reg_images = 0
|
881 |
|
882 |
+
for subset in subsets:
|
883 |
+
if subset.num_repeats < 1:
|
884 |
+
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
885 |
+
continue
|
886 |
+
|
887 |
+
if subset in self.subsets:
|
888 |
+
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
889 |
+
continue
|
890 |
+
|
891 |
+
# メタデータを読み込む
|
892 |
+
if os.path.exists(subset.metadata_file):
|
893 |
+
print(f"loading existing metadata: {subset.metadata_file}")
|
894 |
+
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
895 |
+
metadata = json.load(f)
|
896 |
else:
|
897 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
898 |
|
899 |
+
if len(metadata) < 1:
|
900 |
+
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
901 |
+
continue
|
902 |
|
903 |
+
tags_list = []
|
904 |
+
for image_key, img_md in metadata.items():
|
905 |
+
# path情報を作る
|
906 |
+
if os.path.exists(image_key):
|
907 |
+
abs_path = image_key
|
908 |
+
else:
|
909 |
+
# わりといい加減だがいい方法が思いつかん
|
910 |
+
abs_path = glob_images(subset.image_dir, image_key)
|
911 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
912 |
+
abs_path = abs_path[0]
|
913 |
|
914 |
+
caption = img_md.get('caption')
|
915 |
+
tags = img_md.get('tags')
|
916 |
+
if caption is None:
|
917 |
+
caption = tags
|
918 |
+
elif tags is not None and len(tags) > 0:
|
919 |
+
caption = caption + ', ' + tags
|
920 |
+
tags_list.append(tags)
|
921 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
|
|
|
|
|
|
|
|
922 |
|
923 |
+
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
924 |
+
image_info.image_size = img_md.get('train_resolution')
|
|
|
925 |
|
926 |
+
if not subset.color_aug and not subset.random_crop:
|
927 |
+
# if npz exists, use them
|
928 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
929 |
+
|
930 |
+
self.register_image(image_info, subset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
931 |
|
932 |
+
self.num_train_images += len(metadata) * subset.num_repeats
|
933 |
+
|
934 |
+
# TODO do not record tag freq when no tag
|
935 |
+
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
|
936 |
+
subset.img_count = len(metadata)
|
937 |
+
self.subsets.append(subset)
|
938 |
|
939 |
# check existence of all npz files
|
940 |
+
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
|
941 |
if use_npz_latents:
|
942 |
+
flip_aug_in_subset = False
|
943 |
npz_any = False
|
944 |
npz_all = True
|
945 |
+
|
946 |
for image_info in self.image_data.values():
|
947 |
+
subset = self.image_to_subset[image_info.image_key]
|
948 |
+
|
949 |
has_npz = image_info.latents_npz is not None
|
950 |
npz_any = npz_any or has_npz
|
951 |
|
952 |
+
if subset.flip_aug:
|
953 |
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
954 |
+
flip_aug_in_subset = True
|
955 |
npz_all = npz_all and has_npz
|
956 |
|
957 |
if npz_any and not npz_all:
|
|
|
963 |
elif not npz_all:
|
964 |
use_npz_latents = False
|
965 |
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
966 |
+
if flip_aug_in_subset:
|
967 |
print("maybe no flipped files / ��転されたnpzファイルがないのかもしれません")
|
968 |
# else:
|
969 |
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
|
|
1009 |
for image_info in self.image_data.values():
|
1010 |
image_info.latents_npz = image_info.latents_npz_flipped = None
|
1011 |
|
1012 |
+
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
1013 |
base_name = os.path.splitext(image_key)[0]
|
1014 |
npz_file_norm = base_name + '.npz'
|
1015 |
|
|
|
1021 |
return npz_file_norm, npz_file_flip
|
1022 |
|
1023 |
# image_key is relative path
|
1024 |
+
npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
|
1025 |
+
npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
|
1026 |
|
1027 |
if not os.path.exists(npz_file_norm):
|
1028 |
npz_file_norm = None
|
|
|
1033 |
return npz_file_norm, npz_file_flip
|
1034 |
|
1035 |
|
1036 |
+
# behave as Dataset mock
|
1037 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1038 |
+
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
1039 |
+
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
1040 |
+
|
1041 |
+
super().__init__(datasets)
|
1042 |
+
|
1043 |
+
self.image_data = {}
|
1044 |
+
self.num_train_images = 0
|
1045 |
+
self.num_reg_images = 0
|
1046 |
+
|
1047 |
+
# simply concat together
|
1048 |
+
# TODO: handling image_data key duplication among dataset
|
1049 |
+
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
1050 |
+
for dataset in datasets:
|
1051 |
+
self.image_data.update(dataset.image_data)
|
1052 |
+
self.num_train_images += dataset.num_train_images
|
1053 |
+
self.num_reg_images += dataset.num_reg_images
|
1054 |
+
|
1055 |
+
def add_replacement(self, str_from, str_to):
|
1056 |
+
for dataset in self.datasets:
|
1057 |
+
dataset.add_replacement(str_from, str_to)
|
1058 |
+
|
1059 |
+
# def make_buckets(self):
|
1060 |
+
# for dataset in self.datasets:
|
1061 |
+
# dataset.make_buckets()
|
1062 |
+
|
1063 |
+
def cache_latents(self, vae):
|
1064 |
+
for i, dataset in enumerate(self.datasets):
|
1065 |
+
print(f"[Dataset {i}]")
|
1066 |
+
dataset.cache_latents(vae)
|
1067 |
+
|
1068 |
+
def is_latent_cacheable(self) -> bool:
|
1069 |
+
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
1070 |
+
|
1071 |
+
def set_current_epoch(self, epoch):
|
1072 |
+
for dataset in self.datasets:
|
1073 |
+
dataset.set_current_epoch(epoch)
|
1074 |
+
|
1075 |
+
def disable_token_padding(self):
|
1076 |
+
for dataset in self.datasets:
|
1077 |
+
dataset.disable_token_padding()
|
1078 |
+
|
1079 |
+
|
1080 |
def debug_dataset(train_dataset, show_input_ids=False):
|
1081 |
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
1082 |
print("Escape for exit. / Escキーで中断、終了します")
|
1083 |
|
1084 |
train_dataset.set_current_epoch(1)
|
1085 |
k = 0
|
1086 |
+
indices = list(range(len(train_dataset)))
|
1087 |
+
random.shuffle(indices)
|
1088 |
+
for i, idx in enumerate(indices):
|
1089 |
+
example = train_dataset[idx]
|
1090 |
if example['latents'] is not None:
|
1091 |
print(f"sample has latents from npz file: {example['latents'].size()}")
|
1092 |
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
|
|
1491 |
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1492 |
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1493 |
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1494 |
+
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
|
1495 |
+
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
|
1496 |
+
|
1497 |
+
|
1498 |
+
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
1499 |
+
parser.add_argument("--optimizer_type", type=str, default="",
|
1500 |
+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
|
1501 |
+
|
1502 |
+
# backward compatibility
|
1503 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1504 |
+
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1505 |
+
parser.add_argument("--use_lion_optimizer", action="store_true",
|
1506 |
+
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
1507 |
+
|
1508 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1509 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
1510 |
+
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
|
1511 |
+
|
1512 |
+
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
|
1513 |
+
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
1514 |
+
|
1515 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1516 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
|
1517 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1518 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1519 |
+
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
1520 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
1521 |
+
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
1522 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
1523 |
|
1524 |
|
1525 |
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
|
|
1543 |
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1544 |
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1545 |
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
|
|
|
|
|
|
|
|
1546 |
parser.add_argument("--mem_eff_attn", action="store_true",
|
1547 |
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1548 |
parser.add_argument("--xformers", action="store_true",
|
|
|
1550 |
parser.add_argument("--vae", type=str, default=None,
|
1551 |
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1552 |
|
|
|
1553 |
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1554 |
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1555 |
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
|
|
1570 |
parser.add_argument("--logging_dir", type=str, default=None,
|
1571 |
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1572 |
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
|
|
|
|
|
|
|
|
1573 |
parser.add_argument("--noise_offset", type=float, default=None,
|
1574 |
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
1575 |
parser.add_argument("--lowram", action="store_true",
|
1576 |
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
1577 |
|
1578 |
+
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
1579 |
+
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
1580 |
+
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
1581 |
+
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
1582 |
+
parser.add_argument("--sample_prompts", type=str, default=None,
|
1583 |
+
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
1584 |
+
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
1585 |
+
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
1586 |
+
'dpmsolver++', 'dpmsingle',
|
1587 |
+
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
1588 |
+
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
1589 |
+
|
1590 |
if support_dreambooth:
|
1591 |
# DreamBooth training
|
1592 |
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
|
|
1608 |
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1609 |
parser.add_argument("--caption_extention", type=str, default=None,
|
1610 |
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1611 |
+
parser.add_argument("--keep_tokens", type=int, default=0,
|
1612 |
+
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)")
|
1613 |
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1614 |
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1615 |
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
|
|
1634 |
if support_caption_dropout:
|
1635 |
# Textual Inversion はcaptionのdropoutをsupportしない
|
1636 |
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1637 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
|
1638 |
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1639 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
|
1640 |
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1641 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
|
1642 |
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1643 |
|
1644 |
if support_dreambooth:
|
|
|
1663 |
# region utils
|
1664 |
|
1665 |
|
1666 |
+
def get_optimizer(args, trainable_params):
|
1667 |
+
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
1668 |
+
|
1669 |
+
optimizer_type = args.optimizer_type
|
1670 |
+
if args.use_8bit_adam:
|
1671 |
+
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
|
1672 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
|
1673 |
+
optimizer_type = "AdamW8bit"
|
1674 |
+
|
1675 |
+
elif args.use_lion_optimizer:
|
1676 |
+
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
|
1677 |
+
optimizer_type = "Lion"
|
1678 |
+
|
1679 |
+
if optimizer_type is None or optimizer_type == "":
|
1680 |
+
optimizer_type = "AdamW"
|
1681 |
+
optimizer_type = optimizer_type.lower()
|
1682 |
+
|
1683 |
+
# 引数を分解する:boolとfloat、tupleのみ対応
|
1684 |
+
optimizer_kwargs = {}
|
1685 |
+
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
1686 |
+
for arg in args.optimizer_args:
|
1687 |
+
key, value = arg.split('=')
|
1688 |
+
|
1689 |
+
value = value.split(",")
|
1690 |
+
for i in range(len(value)):
|
1691 |
+
if value[i].lower() == "true" or value[i].lower() == "false":
|
1692 |
+
value[i] = (value[i].lower() == "true")
|
1693 |
+
else:
|
1694 |
+
value[i] = float(value[i])
|
1695 |
+
if len(value) == 1:
|
1696 |
+
value = value[0]
|
1697 |
+
else:
|
1698 |
+
value = tuple(value)
|
1699 |
+
|
1700 |
+
optimizer_kwargs[key] = value
|
1701 |
+
# print("optkwargs:", optimizer_kwargs)
|
1702 |
+
|
1703 |
+
lr = args.learning_rate
|
1704 |
+
|
1705 |
+
if optimizer_type == "AdamW8bit".lower():
|
1706 |
+
try:
|
1707 |
+
import bitsandbytes as bnb
|
1708 |
+
except ImportError:
|
1709 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1710 |
+
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
1711 |
+
optimizer_class = bnb.optim.AdamW8bit
|
1712 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1713 |
+
|
1714 |
+
elif optimizer_type == "SGDNesterov8bit".lower():
|
1715 |
+
try:
|
1716 |
+
import bitsandbytes as bnb
|
1717 |
+
except ImportError:
|
1718 |
+
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
1719 |
+
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1720 |
+
if "momentum" not in optimizer_kwargs:
|
1721 |
+
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1722 |
+
optimizer_kwargs["momentum"] = 0.9
|
1723 |
+
|
1724 |
+
optimizer_class = bnb.optim.SGD8bit
|
1725 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1726 |
+
|
1727 |
+
elif optimizer_type == "Lion".lower():
|
1728 |
+
try:
|
1729 |
+
import lion_pytorch
|
1730 |
+
except ImportError:
|
1731 |
+
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
1732 |
+
print(f"use Lion optimizer | {optimizer_kwargs}")
|
1733 |
+
optimizer_class = lion_pytorch.Lion
|
1734 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1735 |
+
|
1736 |
+
elif optimizer_type == "SGDNesterov".lower():
|
1737 |
+
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
1738 |
+
if "momentum" not in optimizer_kwargs:
|
1739 |
+
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
|
1740 |
+
optimizer_kwargs["momentum"] = 0.9
|
1741 |
+
|
1742 |
+
optimizer_class = torch.optim.SGD
|
1743 |
+
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
1744 |
+
|
1745 |
+
elif optimizer_type == "DAdaptation".lower():
|
1746 |
+
try:
|
1747 |
+
import dadaptation
|
1748 |
+
except ImportError:
|
1749 |
+
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
1750 |
+
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
1751 |
+
|
1752 |
+
min_lr = lr
|
1753 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1754 |
+
for group in trainable_params:
|
1755 |
+
min_lr = min(min_lr, group.get("lr", lr))
|
1756 |
+
|
1757 |
+
if min_lr <= 0.1:
|
1758 |
+
print(
|
1759 |
+
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
|
1760 |
+
print('recommend option: lr=1.0 / 推奨は1.0です')
|
1761 |
+
|
1762 |
+
optimizer_class = dadaptation.DAdaptAdam
|
1763 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1764 |
+
|
1765 |
+
elif optimizer_type == "Adafactor".lower():
|
1766 |
+
# 引数を確認して適宜補正する
|
1767 |
+
if "relative_step" not in optimizer_kwargs:
|
1768 |
+
optimizer_kwargs["relative_step"] = True # default
|
1769 |
+
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
1770 |
+
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
|
1771 |
+
optimizer_kwargs["relative_step"] = True
|
1772 |
+
print(f"use Adafactor optimizer | {optimizer_kwargs}")
|
1773 |
+
|
1774 |
+
if optimizer_kwargs["relative_step"]:
|
1775 |
+
print(f"relative_step is true / relative_stepがtrueです")
|
1776 |
+
if lr != 0.0:
|
1777 |
+
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
1778 |
+
args.learning_rate = None
|
1779 |
+
|
1780 |
+
# trainable_paramsがgroupだった時の処理:lrを削除する
|
1781 |
+
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
1782 |
+
has_group_lr = False
|
1783 |
+
for group in trainable_params:
|
1784 |
+
p = group.pop("lr", None)
|
1785 |
+
has_group_lr = has_group_lr or (p is not None)
|
1786 |
+
|
1787 |
+
if has_group_lr:
|
1788 |
+
# 一応argsを無効にしてお�� TODO 依存関係が逆転してるのであまり望ましくない
|
1789 |
+
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
|
1790 |
+
args.unet_lr = None
|
1791 |
+
args.text_encoder_lr = None
|
1792 |
+
|
1793 |
+
if args.lr_scheduler != "adafactor":
|
1794 |
+
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
1795 |
+
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
1796 |
+
|
1797 |
+
lr = None
|
1798 |
+
else:
|
1799 |
+
if args.max_grad_norm != 0.0:
|
1800 |
+
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
|
1801 |
+
if args.lr_scheduler != "constant_with_warmup":
|
1802 |
+
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
1803 |
+
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
1804 |
+
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
1805 |
+
|
1806 |
+
optimizer_class = transformers.optimization.Adafactor
|
1807 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1808 |
+
|
1809 |
+
elif optimizer_type == "AdamW".lower():
|
1810 |
+
print(f"use AdamW optimizer | {optimizer_kwargs}")
|
1811 |
+
optimizer_class = torch.optim.AdamW
|
1812 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1813 |
+
|
1814 |
+
else:
|
1815 |
+
# 任意のoptimizerを使う
|
1816 |
+
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
1817 |
+
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
1818 |
+
if "." not in optimizer_type:
|
1819 |
+
optimizer_module = torch.optim
|
1820 |
+
else:
|
1821 |
+
values = optimizer_type.split(".")
|
1822 |
+
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
1823 |
+
optimizer_type = values[-1]
|
1824 |
+
|
1825 |
+
optimizer_class = getattr(optimizer_module, optimizer_type)
|
1826 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
1827 |
+
|
1828 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
1829 |
+
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
1830 |
+
|
1831 |
+
return optimizer_name, optimizer_args, optimizer
|
1832 |
+
|
1833 |
+
|
1834 |
+
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
1835 |
+
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
1836 |
+
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
1837 |
+
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
1838 |
+
|
1839 |
+
|
1840 |
+
def get_scheduler_fix(
|
1841 |
+
name: Union[str, SchedulerType],
|
1842 |
+
optimizer: Optimizer,
|
1843 |
+
num_warmup_steps: Optional[int] = None,
|
1844 |
+
num_training_steps: Optional[int] = None,
|
1845 |
+
num_cycles: int = 1,
|
1846 |
+
power: float = 1.0,
|
1847 |
+
):
|
1848 |
+
"""
|
1849 |
+
Unified API to get any scheduler from its name.
|
1850 |
+
Args:
|
1851 |
+
name (`str` or `SchedulerType`):
|
1852 |
+
The name of the scheduler to use.
|
1853 |
+
optimizer (`torch.optim.Optimizer`):
|
1854 |
+
The optimizer that will be used during training.
|
1855 |
+
num_warmup_steps (`int`, *optional*):
|
1856 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
1857 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1858 |
+
num_training_steps (`int``, *optional*):
|
1859 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
1860 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
1861 |
+
num_cycles (`int`, *optional*):
|
1862 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
1863 |
+
power (`float`, *optional*, defaults to 1.0):
|
1864 |
+
Power factor. See `POLYNOMIAL` scheduler
|
1865 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
1866 |
+
The index of the last epoch when resuming training.
|
1867 |
+
"""
|
1868 |
+
if name.startswith("adafactor"):
|
1869 |
+
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
1870 |
+
initial_lr = float(name.split(':')[1])
|
1871 |
+
# print("adafactor scheduler init lr", initial_lr)
|
1872 |
+
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
1873 |
+
|
1874 |
+
name = SchedulerType(name)
|
1875 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
1876 |
+
if name == SchedulerType.CONSTANT:
|
1877 |
+
return schedule_func(optimizer)
|
1878 |
+
|
1879 |
+
# All other schedulers require `num_warmup_steps`
|
1880 |
+
if num_warmup_steps is None:
|
1881 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
1882 |
+
|
1883 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
1884 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
1885 |
+
|
1886 |
+
# All other schedulers require `num_training_steps`
|
1887 |
+
if num_training_steps is None:
|
1888 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
1889 |
+
|
1890 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
1891 |
+
return schedule_func(
|
1892 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
1893 |
+
)
|
1894 |
+
|
1895 |
+
if name == SchedulerType.POLYNOMIAL:
|
1896 |
+
return schedule_func(
|
1897 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
1898 |
+
)
|
1899 |
+
|
1900 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
1901 |
+
|
1902 |
+
|
1903 |
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1904 |
# backward compatibility
|
1905 |
if args.caption_extention is not None:
|
1906 |
args.caption_extension = args.caption_extention
|
1907 |
args.caption_extention = None
|
1908 |
|
|
|
|
|
|
|
|
|
1909 |
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1910 |
if args.resolution is not None:
|
1911 |
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
|
|
1928 |
|
1929 |
def load_tokenizer(args: argparse.Namespace):
|
1930 |
print("prepare tokenizer")
|
1931 |
+
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
1932 |
+
|
1933 |
+
tokenizer: CLIPTokenizer = None
|
1934 |
+
if args.tokenizer_cache_dir:
|
1935 |
+
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
|
1936 |
+
if os.path.exists(local_tokenizer_path):
|
1937 |
+
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
1938 |
+
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
1939 |
+
|
1940 |
+
if tokenizer is None:
|
1941 |
+
if args.v2:
|
1942 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
1943 |
+
else:
|
1944 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
1945 |
+
|
1946 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
1947 |
print(f"update token length: {args.max_token_length}")
|
1948 |
+
|
1949 |
+
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
1950 |
+
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
1951 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
1952 |
+
|
1953 |
return tokenizer
|
1954 |
|
1955 |
|
|
|
2000 |
|
2001 |
|
2002 |
def load_target_model(args: argparse.Namespace, weight_dtype):
|
2003 |
+
name_or_path = args.pretrained_model_name_or_path
|
2004 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
2005 |
+
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
2006 |
if load_stable_diffusion_format:
|
2007 |
print("load StableDiffusion checkpoint")
|
2008 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
|
2009 |
else:
|
2010 |
print("load Diffusers pretrained models")
|
2011 |
+
try:
|
2012 |
+
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
2013 |
+
except EnvironmentError as ex:
|
2014 |
+
print(
|
2015 |
+
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
2016 |
text_encoder = pipe.text_encoder
|
2017 |
vae = pipe.vae
|
2018 |
unet = pipe.unet
|
|
|
2181 |
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
2182 |
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
2183 |
|
2184 |
+
|
2185 |
+
# scheduler:
|
2186 |
+
SCHEDULER_LINEAR_START = 0.00085
|
2187 |
+
SCHEDULER_LINEAR_END = 0.0120
|
2188 |
+
SCHEDULER_TIMESTEPS = 1000
|
2189 |
+
SCHEDLER_SCHEDULE = 'scaled_linear'
|
2190 |
+
|
2191 |
+
|
2192 |
+
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
2193 |
+
"""
|
2194 |
+
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
2195 |
+
clip skipは対応した
|
2196 |
+
"""
|
2197 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
2198 |
+
return
|
2199 |
+
if args.sample_every_n_epochs is not None:
|
2200 |
+
# sample_every_n_steps は無視する
|
2201 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
2202 |
+
return
|
2203 |
+
else:
|
2204 |
+
if steps % args.sample_every_n_steps != 0:
|
2205 |
+
return
|
2206 |
+
|
2207 |
+
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
2208 |
+
if not os.path.isfile(args.sample_prompts):
|
2209 |
+
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
2210 |
+
return
|
2211 |
+
|
2212 |
+
# ここでCUDAのキャッシュクリアとかしたほうがいいのか……
|
2213 |
+
|
2214 |
+
org_vae_device = vae.device # CPUにいるはず
|
2215 |
+
vae.to(device)
|
2216 |
+
|
2217 |
+
# clip skip 対応のための wrapper を作る
|
2218 |
+
if args.clip_skip is None:
|
2219 |
+
text_encoder_or_wrapper = text_encoder
|
2220 |
+
else:
|
2221 |
+
class Wrapper():
|
2222 |
+
def __init__(self, tenc) -> None:
|
2223 |
+
self.tenc = tenc
|
2224 |
+
self.config = {}
|
2225 |
+
super().__init__()
|
2226 |
+
|
2227 |
+
def __call__(self, input_ids, attention_mask):
|
2228 |
+
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
2229 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
2230 |
+
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
2231 |
+
pooled_output = enc_out['pooler_output']
|
2232 |
+
return encoder_hidden_states, pooled_output # 1st output is only used
|
2233 |
+
|
2234 |
+
text_encoder_or_wrapper = Wrapper(text_encoder)
|
2235 |
+
|
2236 |
+
# read prompts
|
2237 |
+
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
|
2238 |
+
prompts = f.readlines()
|
2239 |
+
|
2240 |
+
# schedulerを用意する
|
2241 |
+
sched_init_args = {}
|
2242 |
+
if args.sample_sampler == "ddim":
|
2243 |
+
scheduler_cls = DDIMScheduler
|
2244 |
+
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
2245 |
+
scheduler_cls = DDPMScheduler
|
2246 |
+
elif args.sample_sampler == "pndm":
|
2247 |
+
scheduler_cls = PNDMScheduler
|
2248 |
+
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
|
2249 |
+
scheduler_cls = LMSDiscreteScheduler
|
2250 |
+
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
|
2251 |
+
scheduler_cls = EulerDiscreteScheduler
|
2252 |
+
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
|
2253 |
+
scheduler_cls = EulerAncestralDiscreteScheduler
|
2254 |
+
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
2255 |
+
scheduler_cls = DPMSolverMultistepScheduler
|
2256 |
+
sched_init_args['algorithm_type'] = args.sample_sampler
|
2257 |
+
elif args.sample_sampler == "dpmsingle":
|
2258 |
+
scheduler_cls = DPMSolverSinglestepScheduler
|
2259 |
+
elif args.sample_sampler == "heun":
|
2260 |
+
scheduler_cls = HeunDiscreteScheduler
|
2261 |
+
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
|
2262 |
+
scheduler_cls = KDPM2DiscreteScheduler
|
2263 |
+
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
|
2264 |
+
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
2265 |
+
else:
|
2266 |
+
scheduler_cls = DDIMScheduler
|
2267 |
+
|
2268 |
+
if args.v_parameterization:
|
2269 |
+
sched_init_args['prediction_type'] = 'v_prediction'
|
2270 |
+
|
2271 |
+
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
|
2272 |
+
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
|
2273 |
+
beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)
|
2274 |
+
|
2275 |
+
# clip_sample=Trueにする
|
2276 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
2277 |
+
# print("set clip_sample to True")
|
2278 |
+
scheduler.config.clip_sample = True
|
2279 |
+
|
2280 |
+
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
2281 |
+
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
2282 |
+
pipeline.to(device)
|
2283 |
+
|
2284 |
+
save_dir = args.output_dir + "/sample"
|
2285 |
+
os.makedirs(save_dir, exist_ok=True)
|
2286 |
+
|
2287 |
+
rng_state = torch.get_rng_state()
|
2288 |
+
cuda_rng_state = torch.cuda.get_rng_state()
|
2289 |
+
|
2290 |
+
with torch.no_grad():
|
2291 |
+
with accelerator.autocast():
|
2292 |
+
for i, prompt in enumerate(prompts):
|
2293 |
+
prompt = prompt.strip()
|
2294 |
+
if len(prompt) == 0 or prompt[0] == '#':
|
2295 |
+
continue
|
2296 |
+
|
2297 |
+
# subset of gen_img_diffusers
|
2298 |
+
prompt_args = prompt.split(' --')
|
2299 |
+
prompt = prompt_args[0]
|
2300 |
+
negative_prompt = None
|
2301 |
+
sample_steps = 30
|
2302 |
+
width = height = 512
|
2303 |
+
scale = 7.5
|
2304 |
+
seed = None
|
2305 |
+
for parg in prompt_args:
|
2306 |
+
try:
|
2307 |
+
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
|
2308 |
+
if m:
|
2309 |
+
width = int(m.group(1))
|
2310 |
+
continue
|
2311 |
+
|
2312 |
+
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
|
2313 |
+
if m:
|
2314 |
+
height = int(m.group(1))
|
2315 |
+
continue
|
2316 |
+
|
2317 |
+
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
|
2318 |
+
if m:
|
2319 |
+
seed = int(m.group(1))
|
2320 |
+
continue
|
2321 |
+
|
2322 |
+
m = re.match(r's (\d+)', parg, re.IGNORECASE)
|
2323 |
+
if m: # steps
|
2324 |
+
sample_steps = max(1, min(1000, int(m.group(1))))
|
2325 |
+
continue
|
2326 |
+
|
2327 |
+
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
|
2328 |
+
if m: # scale
|
2329 |
+
scale = float(m.group(1))
|
2330 |
+
continue
|
2331 |
+
|
2332 |
+
m = re.match(r'n (.+)', parg, re.IGNORECASE)
|
2333 |
+
if m: # negative prompt
|
2334 |
+
negative_prompt = m.group(1)
|
2335 |
+
continue
|
2336 |
+
|
2337 |
+
except ValueError as ex:
|
2338 |
+
print(f"Exception in parsing / 解析エラー: {parg}")
|
2339 |
+
print(ex)
|
2340 |
+
|
2341 |
+
if seed is not None:
|
2342 |
+
torch.manual_seed(seed)
|
2343 |
+
torch.cuda.manual_seed(seed)
|
2344 |
+
|
2345 |
+
if prompt_replacement is not None:
|
2346 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2347 |
+
if negative_prompt is not None:
|
2348 |
+
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
2349 |
+
|
2350 |
+
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
2351 |
+
|
2352 |
+
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
2353 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
2354 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
2355 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
2356 |
+
|
2357 |
+
image.save(os.path.join(save_dir, img_filename))
|
2358 |
+
|
2359 |
+
torch.set_rng_state(rng_state)
|
2360 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
2361 |
+
vae.to(org_vae_device)
|
2362 |
+
|
2363 |
# endregion
|
2364 |
|
2365 |
# region 前処理用
|
networks/lora.py
CHANGED
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
129 |
def load_weights(self, file):
|
130 |
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
from safetensors.torch import load_file, safe_open
|
|
|
126 |
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
names.add(lora.lora_name)
|
128 |
|
129 |
+
def set_multiplier(self, multiplier):
|
130 |
+
self.multiplier = multiplier
|
131 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
132 |
+
lora.multiplier = self.multiplier
|
133 |
+
|
134 |
def load_weights(self, file):
|
135 |
if os.path.splitext(file)[1] == '.safetensors':
|
136 |
from safetensors.torch import load_file, safe_open
|
requirements.txt
CHANGED
@@ -12,6 +12,8 @@ safetensors==0.2.6
|
|
12 |
gradio==3.16.2
|
13 |
altair==4.2.2
|
14 |
easygui==0.98.3
|
|
|
|
|
15 |
# for BLIP captioning
|
16 |
requests==2.28.2
|
17 |
timm==0.6.12
|
@@ -21,5 +23,4 @@ fairscale==0.4.13
|
|
21 |
tensorflow==2.10.1
|
22 |
huggingface-hub==0.12.0
|
23 |
# for kohya_ss library
|
24 |
-
#locon.locon_kohya
|
25 |
.
|
|
|
12 |
gradio==3.16.2
|
13 |
altair==4.2.2
|
14 |
easygui==0.98.3
|
15 |
+
toml==0.10.2
|
16 |
+
voluptuous==0.13.1
|
17 |
# for BLIP captioning
|
18 |
requests==2.28.2
|
19 |
timm==0.6.12
|
|
|
23 |
tensorflow==2.10.1
|
24 |
huggingface-hub==0.12.0
|
25 |
# for kohya_ss library
|
|
|
26 |
.
|
train_db.py
CHANGED
@@ -15,7 +15,11 @@ import diffusers
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def collate_fn(examples):
|
@@ -33,24 +37,33 @@ def train(args):
|
|
33 |
|
34 |
tokenizer = train_util.load_tokenizer(args)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
|
|
49 |
|
50 |
if args.debug_dataset:
|
51 |
-
train_util.debug_dataset(
|
52 |
return
|
53 |
|
|
|
|
|
|
|
54 |
# acceleratorを準備する
|
55 |
print("prepare accelerator")
|
56 |
|
@@ -91,7 +104,7 @@ def train(args):
|
|
91 |
vae.requires_grad_(False)
|
92 |
vae.eval()
|
93 |
with torch.no_grad():
|
94 |
-
|
95 |
vae.to("cpu")
|
96 |
if torch.cuda.is_available():
|
97 |
torch.cuda.empty_cache()
|
@@ -115,38 +128,18 @@ def train(args):
|
|
115 |
|
116 |
# 学習に必要なクラスを準備する
|
117 |
print("prepare optimizer, data loader etc.")
|
118 |
-
|
119 |
-
# 8-bit Adamを使う
|
120 |
-
if args.use_8bit_adam:
|
121 |
-
try:
|
122 |
-
import bitsandbytes as bnb
|
123 |
-
except ImportError:
|
124 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
125 |
-
print("use 8-bit Adam optimizer")
|
126 |
-
optimizer_class = bnb.optim.AdamW8bit
|
127 |
-
elif args.use_lion_optimizer:
|
128 |
-
try:
|
129 |
-
import lion_pytorch
|
130 |
-
except ImportError:
|
131 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
132 |
-
print("use Lion optimizer")
|
133 |
-
optimizer_class = lion_pytorch.Lion
|
134 |
-
else:
|
135 |
-
optimizer_class = torch.optim.AdamW
|
136 |
-
|
137 |
if train_text_encoder:
|
138 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
139 |
else:
|
140 |
trainable_params = unet.parameters()
|
141 |
|
142 |
-
|
143 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
144 |
|
145 |
# dataloaderを準備する
|
146 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
147 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
148 |
train_dataloader = torch.utils.data.DataLoader(
|
149 |
-
|
150 |
|
151 |
# 学習ステップ数を計算する
|
152 |
if args.max_train_epochs is not None:
|
@@ -156,9 +149,10 @@ def train(args):
|
|
156 |
if args.stop_text_encoder_training is None:
|
157 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
158 |
|
159 |
-
# lr schedulerを用意する
|
160 |
-
lr_scheduler =
|
161 |
-
|
|
|
162 |
|
163 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
164 |
if args.full_fp16:
|
@@ -195,8 +189,8 @@ def train(args):
|
|
195 |
# 学習する
|
196 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
197 |
print("running training / 学習開始")
|
198 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
199 |
-
print(f" num reg images / 正則化画像の数: {
|
200 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
201 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
202 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -217,7 +211,7 @@ def train(args):
|
|
217 |
loss_total = 0.0
|
218 |
for epoch in range(num_train_epochs):
|
219 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
220 |
-
|
221 |
|
222 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
223 |
unet.train()
|
@@ -281,12 +275,12 @@ def train(args):
|
|
281 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
282 |
|
283 |
accelerator.backward(loss)
|
284 |
-
if accelerator.sync_gradients:
|
285 |
if train_text_encoder:
|
286 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
287 |
else:
|
288 |
params_to_clip = unet.parameters()
|
289 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
290 |
|
291 |
optimizer.step()
|
292 |
lr_scheduler.step()
|
@@ -297,9 +291,13 @@ def train(args):
|
|
297 |
progress_bar.update(1)
|
298 |
global_step += 1
|
299 |
|
|
|
|
|
300 |
current_loss = loss.detach().item()
|
301 |
if args.logging_dir is not None:
|
302 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
303 |
accelerator.log(logs, step=global_step)
|
304 |
|
305 |
if epoch == 0:
|
@@ -326,6 +324,8 @@ def train(args):
|
|
326 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
327 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
328 |
|
|
|
|
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
@@ -352,6 +352,8 @@ if __name__ == '__main__':
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
|
|
|
|
355 |
|
356 |
parser.add_argument("--no_token_padding", action="store_true",
|
357 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
|
|
15 |
from diffusers import DDPMScheduler
|
16 |
|
17 |
import library.train_util as train_util
|
18 |
+
import library.config_util as config_util
|
19 |
+
from library.config_util import (
|
20 |
+
ConfigSanitizer,
|
21 |
+
BlueprintGenerator,
|
22 |
+
)
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
|
|
37 |
|
38 |
tokenizer = train_util.load_tokenizer(args)
|
39 |
|
40 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
41 |
+
if args.dataset_config is not None:
|
42 |
+
print(f"Load dataset config from {args.dataset_config}")
|
43 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
44 |
+
ignored = ["train_data_dir", "reg_data_dir"]
|
45 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
46 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
47 |
+
else:
|
48 |
+
user_config = {
|
49 |
+
"datasets": [{
|
50 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
51 |
+
}]
|
52 |
+
}
|
53 |
|
54 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
55 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
56 |
|
57 |
+
if args.no_token_padding:
|
58 |
+
train_dataset_group.disable_token_padding()
|
59 |
|
60 |
if args.debug_dataset:
|
61 |
+
train_util.debug_dataset(train_dataset_group)
|
62 |
return
|
63 |
|
64 |
+
if cache_latents:
|
65 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
66 |
+
|
67 |
# acceleratorを準備する
|
68 |
print("prepare accelerator")
|
69 |
|
|
|
104 |
vae.requires_grad_(False)
|
105 |
vae.eval()
|
106 |
with torch.no_grad():
|
107 |
+
train_dataset_group.cache_latents(vae)
|
108 |
vae.to("cpu")
|
109 |
if torch.cuda.is_available():
|
110 |
torch.cuda.empty_cache()
|
|
|
128 |
|
129 |
# 学習に必要なクラスを準備する
|
130 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if train_text_encoder:
|
132 |
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
133 |
else:
|
134 |
trainable_params = unet.parameters()
|
135 |
|
136 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
137 |
|
138 |
# dataloaderを準備する
|
139 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
140 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
141 |
train_dataloader = torch.utils.data.DataLoader(
|
142 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
143 |
|
144 |
# 学習ステップ数を計算する
|
145 |
if args.max_train_epochs is not None:
|
|
|
149 |
if args.stop_text_encoder_training is None:
|
150 |
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
151 |
|
152 |
+
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
153 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
154 |
+
num_training_steps=args.max_train_steps,
|
155 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
156 |
|
157 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
158 |
if args.full_fp16:
|
|
|
189 |
# 学習する
|
190 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
191 |
print("running training / 学習開始")
|
192 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
193 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
194 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
195 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
196 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
211 |
loss_total = 0.0
|
212 |
for epoch in range(num_train_epochs):
|
213 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
214 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
215 |
|
216 |
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
217 |
unet.train()
|
|
|
275 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
276 |
|
277 |
accelerator.backward(loss)
|
278 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
279 |
if train_text_encoder:
|
280 |
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
281 |
else:
|
282 |
params_to_clip = unet.parameters()
|
283 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
284 |
|
285 |
optimizer.step()
|
286 |
lr_scheduler.step()
|
|
|
291 |
progress_bar.update(1)
|
292 |
global_step += 1
|
293 |
|
294 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
295 |
+
|
296 |
current_loss = loss.detach().item()
|
297 |
if args.logging_dir is not None:
|
298 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
299 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
300 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
301 |
accelerator.log(logs, step=global_step)
|
302 |
|
303 |
if epoch == 0:
|
|
|
324 |
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
325 |
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
326 |
|
327 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
328 |
+
|
329 |
is_main_process = accelerator.is_main_process
|
330 |
if is_main_process:
|
331 |
unet = unwrap_model(unet)
|
|
|
352 |
train_util.add_dataset_arguments(parser, True, False, True)
|
353 |
train_util.add_training_arguments(parser, True)
|
354 |
train_util.add_sd_saving_arguments(parser)
|
355 |
+
train_util.add_optimizer_arguments(parser)
|
356 |
+
config_util.add_config_arguments(parser)
|
357 |
|
358 |
parser.add_argument("--no_token_padding", action="store_true",
|
359 |
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
train_network.py
CHANGED
@@ -1,8 +1,4 @@
|
|
1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
-
from torch.optim import Optimizer
|
3 |
-
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
-
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
@@ -15,92 +11,39 @@ import json
|
|
15 |
from tqdm import tqdm
|
16 |
import torch
|
17 |
from accelerate.utils import set_seed
|
18 |
-
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
|
21 |
import library.train_util as train_util
|
22 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def collate_fn(examples):
|
26 |
return examples[0]
|
27 |
|
28 |
|
|
|
29 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
30 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
31 |
|
32 |
if args.network_train_unet_only:
|
33 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
34 |
elif args.network_train_text_encoder_only:
|
35 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
36 |
else:
|
37 |
-
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
38 |
-
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
39 |
-
|
40 |
-
return logs
|
41 |
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
45 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
46 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
47 |
-
|
48 |
-
|
49 |
-
def get_scheduler_fix(
|
50 |
-
name: Union[str, SchedulerType],
|
51 |
-
optimizer: Optimizer,
|
52 |
-
num_warmup_steps: Optional[int] = None,
|
53 |
-
num_training_steps: Optional[int] = None,
|
54 |
-
num_cycles: int = 1,
|
55 |
-
power: float = 1.0,
|
56 |
-
):
|
57 |
-
"""
|
58 |
-
Unified API to get any scheduler from its name.
|
59 |
-
Args:
|
60 |
-
name (`str` or `SchedulerType`):
|
61 |
-
The name of the scheduler to use.
|
62 |
-
optimizer (`torch.optim.Optimizer`):
|
63 |
-
The optimizer that will be used during training.
|
64 |
-
num_warmup_steps (`int`, *optional*):
|
65 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
66 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
67 |
-
num_training_steps (`int``, *optional*):
|
68 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
69 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
70 |
-
num_cycles (`int`, *optional*):
|
71 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
72 |
-
power (`float`, *optional*, defaults to 1.0):
|
73 |
-
Power factor. See `POLYNOMIAL` scheduler
|
74 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
75 |
-
The index of the last epoch when resuming training.
|
76 |
-
"""
|
77 |
-
name = SchedulerType(name)
|
78 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
79 |
-
if name == SchedulerType.CONSTANT:
|
80 |
-
return schedule_func(optimizer)
|
81 |
-
|
82 |
-
# All other schedulers require `num_warmup_steps`
|
83 |
-
if num_warmup_steps is None:
|
84 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
85 |
-
|
86 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
87 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
88 |
-
|
89 |
-
# All other schedulers require `num_training_steps`
|
90 |
-
if num_training_steps is None:
|
91 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
92 |
-
|
93 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
94 |
-
return schedule_func(
|
95 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
96 |
-
)
|
97 |
-
|
98 |
-
if name == SchedulerType.POLYNOMIAL:
|
99 |
-
return schedule_func(
|
100 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
101 |
-
)
|
102 |
-
|
103 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
104 |
|
105 |
|
106 |
def train(args):
|
@@ -111,6 +54,7 @@ def train(args):
|
|
111 |
|
112 |
cache_latents = args.cache_latents
|
113 |
use_dreambooth_method = args.in_json is None
|
|
|
114 |
|
115 |
if args.seed is not None:
|
116 |
set_seed(args.seed)
|
@@ -118,35 +62,47 @@ def train(args):
|
|
118 |
tokenizer = train_util.load_tokenizer(args)
|
119 |
|
120 |
# データセットを準備する
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
else:
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if args.debug_dataset:
|
144 |
-
train_util.debug_dataset(
|
145 |
return
|
146 |
-
if len(
|
147 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# acceleratorを準備する
|
151 |
print("prepare accelerator")
|
152 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -161,7 +117,7 @@ def train(args):
|
|
161 |
if args.lowram:
|
162 |
text_encoder.to("cuda")
|
163 |
unet.to("cuda")
|
164 |
-
|
165 |
# モデルに xformers とか memory efficient attention を組み込む
|
166 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
167 |
|
@@ -171,7 +127,7 @@ def train(args):
|
|
171 |
vae.requires_grad_(False)
|
172 |
vae.eval()
|
173 |
with torch.no_grad():
|
174 |
-
|
175 |
vae.to("cpu")
|
176 |
if torch.cuda.is_available():
|
177 |
torch.cuda.empty_cache()
|
@@ -208,36 +164,14 @@ def train(args):
|
|
208 |
# 学習に必要なクラスを準備する
|
209 |
print("prepare optimizer, data loader etc.")
|
210 |
|
211 |
-
# 8-bit Adamを使う
|
212 |
-
if args.use_8bit_adam:
|
213 |
-
try:
|
214 |
-
import bitsandbytes as bnb
|
215 |
-
except ImportError:
|
216 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
217 |
-
print("use 8-bit Adam optimizer")
|
218 |
-
optimizer_class = bnb.optim.AdamW8bit
|
219 |
-
elif args.use_lion_optimizer:
|
220 |
-
try:
|
221 |
-
import lion_pytorch
|
222 |
-
except ImportError:
|
223 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされて��ないようです")
|
224 |
-
print("use Lion optimizer")
|
225 |
-
optimizer_class = lion_pytorch.Lion
|
226 |
-
else:
|
227 |
-
optimizer_class = torch.optim.AdamW
|
228 |
-
|
229 |
-
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
230 |
-
|
231 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
232 |
-
|
233 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
234 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
235 |
|
236 |
# dataloaderを準備する
|
237 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
238 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
239 |
train_dataloader = torch.utils.data.DataLoader(
|
240 |
-
|
241 |
|
242 |
# 学習ステップ数を計算する
|
243 |
if args.max_train_epochs is not None:
|
@@ -245,11 +179,9 @@ def train(args):
|
|
245 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
246 |
|
247 |
# lr schedulerを用意する
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
252 |
-
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
253 |
|
254 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
255 |
if args.full_fp16:
|
@@ -317,17 +249,19 @@ def train(args):
|
|
317 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
318 |
|
319 |
# 学習する
|
|
|
320 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
321 |
print("running training / 学習開始")
|
322 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
323 |
-
print(f" num reg images / 正則化画像の数: {
|
324 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
325 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
326 |
-
print(f" batch size per device / バッチサイズ: {
|
327 |
-
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
328 |
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
329 |
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
330 |
|
|
|
331 |
metadata = {
|
332 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
333 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -335,12 +269,10 @@ def train(args):
|
|
335 |
"ss_learning_rate": args.learning_rate,
|
336 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
337 |
"ss_unet_lr": args.unet_lr,
|
338 |
-
"ss_num_train_images":
|
339 |
-
"ss_num_reg_images":
|
340 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
341 |
"ss_num_epochs": num_train_epochs,
|
342 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
343 |
-
"ss_total_batch_size": total_batch_size,
|
344 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
345 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
346 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -352,29 +284,149 @@ def train(args):
|
|
352 |
"ss_mixed_precision": args.mixed_precision,
|
353 |
"ss_full_fp16": bool(args.full_fp16),
|
354 |
"ss_v2": bool(args.v2),
|
355 |
-
"ss_resolution": args.resolution,
|
356 |
"ss_clip_skip": args.clip_skip,
|
357 |
"ss_max_token_length": args.max_token_length,
|
358 |
-
"ss_color_aug": bool(args.color_aug),
|
359 |
-
"ss_flip_aug": bool(args.flip_aug),
|
360 |
-
"ss_random_crop": bool(args.random_crop),
|
361 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
362 |
"ss_cache_latents": bool(args.cache_latents),
|
363 |
-
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
364 |
-
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
365 |
-
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
366 |
"ss_seed": args.seed,
|
367 |
-
"
|
368 |
"ss_noise_offset": args.noise_offset,
|
369 |
-
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
370 |
-
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
371 |
-
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
372 |
-
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
373 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
374 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
375 |
-
"ss_optimizer": optimizer_name
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
}
|
377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
# uncomment if another network is added
|
379 |
# for key, value in net_kwargs.items():
|
380 |
# metadata["ss_arg_" + key] = value
|
@@ -410,7 +462,7 @@ def train(args):
|
|
410 |
loss_total = 0.0
|
411 |
for epoch in range(num_train_epochs):
|
412 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
413 |
-
|
414 |
|
415 |
metadata["ss_epoch"] = str(epoch+1)
|
416 |
|
@@ -447,7 +499,7 @@ def train(args):
|
|
447 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
448 |
|
449 |
# Predict the noise residual
|
450 |
-
with autocast():
|
451 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
452 |
|
453 |
if args.v_parameterization:
|
@@ -465,9 +517,9 @@ def train(args):
|
|
465 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
466 |
|
467 |
accelerator.backward(loss)
|
468 |
-
if accelerator.sync_gradients:
|
469 |
params_to_clip = network.get_trainable_params()
|
470 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
471 |
|
472 |
optimizer.step()
|
473 |
lr_scheduler.step()
|
@@ -478,6 +530,8 @@ def train(args):
|
|
478 |
progress_bar.update(1)
|
479 |
global_step += 1
|
480 |
|
|
|
|
|
481 |
current_loss = loss.detach().item()
|
482 |
if epoch == 0:
|
483 |
loss_list.append(current_loss)
|
@@ -508,6 +562,7 @@ def train(args):
|
|
508 |
def save_func():
|
509 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
510 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
511 |
print(f"saving checkpoint: {ckpt_file}")
|
512 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
513 |
|
@@ -522,9 +577,12 @@ def train(args):
|
|
522 |
if saving and args.save_state:
|
523 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
524 |
|
|
|
|
|
525 |
# end of epoch
|
526 |
|
527 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
528 |
|
529 |
is_main_process = accelerator.is_main_process
|
530 |
if is_main_process:
|
@@ -555,6 +613,8 @@ if __name__ == '__main__':
|
|
555 |
train_util.add_sd_models_arguments(parser)
|
556 |
train_util.add_dataset_arguments(parser, True, True, True)
|
557 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
558 |
|
559 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
560 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -562,10 +622,6 @@ if __name__ == '__main__':
|
|
562 |
|
563 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
564 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
565 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
566 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
567 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
568 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
569 |
|
570 |
parser.add_argument("--network_weights", type=str, default=None,
|
571 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
|
|
|
|
|
|
1 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
2 |
import importlib
|
3 |
import argparse
|
4 |
import gc
|
|
|
11 |
from tqdm import tqdm
|
12 |
import torch
|
13 |
from accelerate.utils import set_seed
|
|
|
14 |
from diffusers import DDPMScheduler
|
15 |
|
16 |
import library.train_util as train_util
|
17 |
+
from library.train_util import (
|
18 |
+
DreamBoothDataset,
|
19 |
+
)
|
20 |
+
import library.config_util as config_util
|
21 |
+
from library.config_util import (
|
22 |
+
ConfigSanitizer,
|
23 |
+
BlueprintGenerator,
|
24 |
+
)
|
25 |
|
26 |
|
27 |
def collate_fn(examples):
|
28 |
return examples[0]
|
29 |
|
30 |
|
31 |
+
# TODO 他のスクリプトと共通化する
|
32 |
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
33 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
34 |
|
35 |
if args.network_train_unet_only:
|
36 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
37 |
elif args.network_train_text_encoder_only:
|
38 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
39 |
else:
|
40 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
41 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
|
|
|
|
42 |
|
43 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
44 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
45 |
|
46 |
+
return logs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
def train(args):
|
|
|
54 |
|
55 |
cache_latents = args.cache_latents
|
56 |
use_dreambooth_method = args.in_json is None
|
57 |
+
use_user_config = args.dataset_config is not None
|
58 |
|
59 |
if args.seed is not None:
|
60 |
set_seed(args.seed)
|
|
|
62 |
tokenizer = train_util.load_tokenizer(args)
|
63 |
|
64 |
# データセットを準備する
|
65 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
66 |
+
if use_user_config:
|
67 |
+
print(f"Load dataset config from {args.dataset_config}")
|
68 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
69 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
70 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
71 |
+
print(
|
72 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
73 |
else:
|
74 |
+
if use_dreambooth_method:
|
75 |
+
print("Use DreamBooth method.")
|
76 |
+
user_config = {
|
77 |
+
"datasets": [{
|
78 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
79 |
+
}]
|
80 |
+
}
|
81 |
+
else:
|
82 |
+
print("Train with captions.")
|
83 |
+
user_config = {
|
84 |
+
"datasets": [{
|
85 |
+
"subsets": [{
|
86 |
+
"image_dir": args.train_data_dir,
|
87 |
+
"metadata_file": args.in_json,
|
88 |
+
}]
|
89 |
+
}]
|
90 |
+
}
|
91 |
+
|
92 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
93 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
94 |
|
95 |
if args.debug_dataset:
|
96 |
+
train_util.debug_dataset(train_dataset_group)
|
97 |
return
|
98 |
+
if len(train_dataset_group) == 0:
|
99 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
100 |
return
|
101 |
|
102 |
+
if cache_latents:
|
103 |
+
assert train_dataset_group.is_latent_cacheable(
|
104 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
105 |
+
|
106 |
# acceleratorを準備する
|
107 |
print("prepare accelerator")
|
108 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
117 |
if args.lowram:
|
118 |
text_encoder.to("cuda")
|
119 |
unet.to("cuda")
|
120 |
+
|
121 |
# モデルに xformers とか memory efficient attention を組み込む
|
122 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
123 |
|
|
|
127 |
vae.requires_grad_(False)
|
128 |
vae.eval()
|
129 |
with torch.no_grad():
|
130 |
+
train_dataset_group.cache_latents(vae)
|
131 |
vae.to("cpu")
|
132 |
if torch.cuda.is_available():
|
133 |
torch.cuda.empty_cache()
|
|
|
164 |
# 学習に必要なクラスを準備する
|
165 |
print("prepare optimizer, data loader etc.")
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
168 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
|
|
169 |
|
170 |
# dataloaderを準備する
|
171 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
172 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
173 |
train_dataloader = torch.utils.data.DataLoader(
|
174 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
175 |
|
176 |
# 学習ステップ数を計算する
|
177 |
if args.max_train_epochs is not None:
|
|
|
179 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
180 |
|
181 |
# lr schedulerを用意する
|
182 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
183 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
184 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
|
|
|
|
185 |
|
186 |
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
187 |
if args.full_fp16:
|
|
|
249 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
250 |
|
251 |
# 学習する
|
252 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
253 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
254 |
print("running training / 学習開始")
|
255 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
256 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
257 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
258 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
259 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
260 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
261 |
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
262 |
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
263 |
|
264 |
+
# TODO refactor metadata creation and move to util
|
265 |
metadata = {
|
266 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
267 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
269 |
"ss_learning_rate": args.learning_rate,
|
270 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
271 |
"ss_unet_lr": args.unet_lr,
|
272 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
273 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
274 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
275 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
276 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
277 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
278 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
284 |
"ss_mixed_precision": args.mixed_precision,
|
285 |
"ss_full_fp16": bool(args.full_fp16),
|
286 |
"ss_v2": bool(args.v2),
|
|
|
287 |
"ss_clip_skip": args.clip_skip,
|
288 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
289 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
290 |
"ss_seed": args.seed,
|
291 |
+
"ss_lowram": args.lowram,
|
292 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
293 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
294 |
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
295 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
296 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
297 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
298 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
299 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
300 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
301 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
302 |
}
|
303 |
|
304 |
+
if use_user_config:
|
305 |
+
# save metadata of multiple datasets
|
306 |
+
# NOTE: pack "ss_datasets" value as json one time
|
307 |
+
# or should also pack nested collections as json?
|
308 |
+
datasets_metadata = []
|
309 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
310 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
311 |
+
|
312 |
+
for dataset in train_dataset_group.datasets:
|
313 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
314 |
+
dataset_metadata = {
|
315 |
+
"is_dreambooth": is_dreambooth_dataset,
|
316 |
+
"batch_size_per_device": dataset.batch_size,
|
317 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
318 |
+
"num_reg_images": dataset.num_reg_images,
|
319 |
+
"resolution": (dataset.width, dataset.height),
|
320 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
321 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
322 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
323 |
+
"tag_frequency": dataset.tag_frequency,
|
324 |
+
"bucket_info": dataset.bucket_info,
|
325 |
+
}
|
326 |
+
|
327 |
+
subsets_metadata = []
|
328 |
+
for subset in dataset.subsets:
|
329 |
+
subset_metadata = {
|
330 |
+
"img_count": subset.img_count,
|
331 |
+
"num_repeats": subset.num_repeats,
|
332 |
+
"color_aug": bool(subset.color_aug),
|
333 |
+
"flip_aug": bool(subset.flip_aug),
|
334 |
+
"random_crop": bool(subset.random_crop),
|
335 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
336 |
+
"keep_tokens": subset.keep_tokens,
|
337 |
+
}
|
338 |
+
|
339 |
+
image_dir_or_metadata_file = None
|
340 |
+
if subset.image_dir:
|
341 |
+
image_dir = os.path.basename(subset.image_dir)
|
342 |
+
subset_metadata["image_dir"] = image_dir
|
343 |
+
image_dir_or_metadata_file = image_dir
|
344 |
+
|
345 |
+
if is_dreambooth_dataset:
|
346 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
347 |
+
subset_metadata["is_reg"] = subset.is_reg
|
348 |
+
if subset.is_reg:
|
349 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
350 |
+
else:
|
351 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
352 |
+
subset_metadata["metadata_file"] = metadata_file
|
353 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
354 |
+
|
355 |
+
subsets_metadata.append(subset_metadata)
|
356 |
+
|
357 |
+
# merge dataset dir: not reg subset only
|
358 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
359 |
+
if image_dir_or_metadata_file is not None:
|
360 |
+
# datasets may have a certain dir multiple times
|
361 |
+
v = image_dir_or_metadata_file
|
362 |
+
i = 2
|
363 |
+
while v in dataset_dirs_info:
|
364 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
365 |
+
i += 1
|
366 |
+
image_dir_or_metadata_file = v
|
367 |
+
|
368 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
369 |
+
"n_repeats": subset.num_repeats,
|
370 |
+
"img_count": subset.img_count
|
371 |
+
}
|
372 |
+
|
373 |
+
dataset_metadata["subsets"] = subsets_metadata
|
374 |
+
datasets_metadata.append(dataset_metadata)
|
375 |
+
|
376 |
+
# merge tag frequency:
|
377 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
378 |
+
# あるデ���レクトリが複数のdatasetで使用されている場合、一度だけ数える
|
379 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
380 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
381 |
+
if ds_dir_name in tag_frequency:
|
382 |
+
continue
|
383 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
384 |
+
|
385 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
386 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
387 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
388 |
+
else:
|
389 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
390 |
+
assert len(
|
391 |
+
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
392 |
+
|
393 |
+
dataset = train_dataset_group.datasets[0]
|
394 |
+
|
395 |
+
dataset_dirs_info = {}
|
396 |
+
reg_dataset_dirs_info = {}
|
397 |
+
if use_dreambooth_method:
|
398 |
+
for subset in dataset.subsets:
|
399 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
400 |
+
info[os.path.basename(subset.image_dir)] = {
|
401 |
+
"n_repeats": subset.num_repeats,
|
402 |
+
"img_count": subset.img_count
|
403 |
+
}
|
404 |
+
else:
|
405 |
+
for subset in dataset.subsets:
|
406 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
407 |
+
"n_repeats": subset.num_repeats,
|
408 |
+
"img_count": subset.img_count
|
409 |
+
}
|
410 |
+
|
411 |
+
metadata.update({
|
412 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
413 |
+
"ss_total_batch_size": total_batch_size,
|
414 |
+
"ss_resolution": args.resolution,
|
415 |
+
"ss_color_aug": bool(args.color_aug),
|
416 |
+
"ss_flip_aug": bool(args.flip_aug),
|
417 |
+
"ss_random_crop": bool(args.random_crop),
|
418 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
419 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
420 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
421 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
422 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
423 |
+
"ss_keep_tokens": args.keep_tokens,
|
424 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
425 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
426 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
427 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
428 |
+
})
|
429 |
+
|
430 |
# uncomment if another network is added
|
431 |
# for key, value in net_kwargs.items():
|
432 |
# metadata["ss_arg_" + key] = value
|
|
|
462 |
loss_total = 0.0
|
463 |
for epoch in range(num_train_epochs):
|
464 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
465 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
466 |
|
467 |
metadata["ss_epoch"] = str(epoch+1)
|
468 |
|
|
|
499 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
500 |
|
501 |
# Predict the noise residual
|
502 |
+
with accelerator.autocast():
|
503 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
504 |
|
505 |
if args.v_parameterization:
|
|
|
517 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
518 |
|
519 |
accelerator.backward(loss)
|
520 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
521 |
params_to_clip = network.get_trainable_params()
|
522 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
523 |
|
524 |
optimizer.step()
|
525 |
lr_scheduler.step()
|
|
|
530 |
progress_bar.update(1)
|
531 |
global_step += 1
|
532 |
|
533 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
534 |
+
|
535 |
current_loss = loss.detach().item()
|
536 |
if epoch == 0:
|
537 |
loss_list.append(current_loss)
|
|
|
562 |
def save_func():
|
563 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
564 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
565 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
566 |
print(f"saving checkpoint: {ckpt_file}")
|
567 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
568 |
|
|
|
577 |
if saving and args.save_state:
|
578 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
579 |
|
580 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
581 |
+
|
582 |
# end of epoch
|
583 |
|
584 |
metadata["ss_epoch"] = str(num_train_epochs)
|
585 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
586 |
|
587 |
is_main_process = accelerator.is_main_process
|
588 |
if is_main_process:
|
|
|
613 |
train_util.add_sd_models_arguments(parser)
|
614 |
train_util.add_dataset_arguments(parser, True, True, True)
|
615 |
train_util.add_training_arguments(parser, True)
|
616 |
+
train_util.add_optimizer_arguments(parser)
|
617 |
+
config_util.add_config_arguments(parser)
|
618 |
|
619 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
620 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
622 |
|
623 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
624 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
625 |
|
626 |
parser.add_argument("--network_weights", type=str, default=None,
|
627 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
train_network_opt.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1 |
-
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
2 |
-
from torch.optim import Optimizer
|
3 |
from torch.cuda.amp import autocast
|
4 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
5 |
-
from typing import Optional, Union
|
6 |
import importlib
|
7 |
import argparse
|
8 |
import gc
|
@@ -17,136 +14,47 @@ import torch
|
|
17 |
from accelerate.utils import set_seed
|
18 |
import diffusers
|
19 |
from diffusers import DDPMScheduler
|
20 |
-
|
21 |
-
#先に
|
22 |
-
#pip install torch_optimizer
|
23 |
-
#が必要
|
24 |
-
try:
|
25 |
-
import torch_optimizer as optim
|
26 |
-
except:
|
27 |
-
print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
|
28 |
-
try:
|
29 |
-
import adastand
|
30 |
-
except:
|
31 |
-
print("※Adastandが使えません")
|
32 |
-
|
33 |
-
from transformers.optimization import Adafactor, AdafactorSchedule
|
34 |
-
print("**********************************")
|
35 |
##### バケット拡張のためのモジュール
|
36 |
import append_module
|
37 |
######
|
38 |
import library.train_util as train_util
|
39 |
-
from library.train_util import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def collate_fn(examples):
|
43 |
return examples[0]
|
44 |
|
45 |
|
46 |
-
|
|
|
47 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
53 |
else:
|
54 |
last_lrs = lr_scheduler.get_last_lr()
|
55 |
-
|
56 |
-
logs["lr/
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
|
61 |
-
elif len(last_lrs) == 8:
|
62 |
-
logs_names = ["textencoder", "unet_midblock"]
|
63 |
-
for i in range(3):
|
64 |
-
logs_names.append(f"unet_down_blocks_{i}")
|
65 |
-
logs_names.append(f"unet_up_blocks_{i+1}")
|
66 |
-
else:
|
67 |
-
logs_names = []
|
68 |
-
for i in range(12):
|
69 |
-
logs_names.append(f"text_model_encoder_layers_{i}_")
|
70 |
-
logs_names.append("unet_midblock")
|
71 |
-
for i in range(3):
|
72 |
-
logs_names.append(f"unet_down_blocks_{i}")
|
73 |
-
logs_names.append(f"unet_up_blocks_{i+1}")
|
74 |
-
|
75 |
-
for last_lr, logs_name in zip(last_lrs, logs_names):
|
76 |
-
logs[f"lr/{logs_name}"] = float(last_lr)
|
77 |
|
78 |
return logs
|
79 |
|
80 |
|
81 |
-
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
82 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
83 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
84 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
85 |
-
|
86 |
-
|
87 |
-
def get_scheduler_fix(
|
88 |
-
name: Union[str, SchedulerType],
|
89 |
-
optimizer: Optimizer,
|
90 |
-
num_warmup_steps: Optional[int] = None,
|
91 |
-
num_training_steps: Optional[int] = None,
|
92 |
-
num_cycles: float = 1.,
|
93 |
-
power: float = 1.0,
|
94 |
-
):
|
95 |
-
"""
|
96 |
-
Unified API to get any scheduler from its name.
|
97 |
-
Args:
|
98 |
-
name (`str` or `SchedulerType`):
|
99 |
-
The name of the scheduler to use.
|
100 |
-
optimizer (`torch.optim.Optimizer`):
|
101 |
-
The optimizer that will be used during training.
|
102 |
-
num_warmup_steps (`int`, *optional*):
|
103 |
-
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
104 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
105 |
-
num_training_steps (`int``, *optional*):
|
106 |
-
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
107 |
-
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
108 |
-
num_cycles (`int`, *optional*):
|
109 |
-
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
110 |
-
power (`float`, *optional*, defaults to 1.0):
|
111 |
-
Power factor. See `POLYNOMIAL` scheduler
|
112 |
-
last_epoch (`int`, *optional*, defaults to -1):
|
113 |
-
The index of the last epoch when resuming training.
|
114 |
-
"""
|
115 |
-
name = SchedulerType(name)
|
116 |
-
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
117 |
-
if name == SchedulerType.CONSTANT:
|
118 |
-
return schedule_func(optimizer)
|
119 |
-
|
120 |
-
# All other schedulers require `num_warmup_steps`
|
121 |
-
if num_warmup_steps is None:
|
122 |
-
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
123 |
-
|
124 |
-
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
125 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
126 |
-
|
127 |
-
# All other schedulers require `num_training_steps`
|
128 |
-
if num_training_steps is None:
|
129 |
-
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
130 |
-
|
131 |
-
if name == SchedulerType.COSINE:
|
132 |
-
print(f"{name} num_cycles: {num_cycles}")
|
133 |
-
return schedule_func(
|
134 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
135 |
-
)
|
136 |
-
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
137 |
-
print(f"{name} num_cycles: {int(num_cycles)}")
|
138 |
-
return schedule_func(
|
139 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=int(num_cycles)
|
140 |
-
)
|
141 |
-
|
142 |
-
if name == SchedulerType.POLYNOMIAL:
|
143 |
-
return schedule_func(
|
144 |
-
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
145 |
-
)
|
146 |
-
|
147 |
-
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
148 |
-
|
149 |
-
|
150 |
def train(args):
|
151 |
session_id = random.randint(0, 2**32)
|
152 |
training_started_at = time.time()
|
@@ -155,6 +63,7 @@ def train(args):
|
|
155 |
|
156 |
cache_latents = args.cache_latents
|
157 |
use_dreambooth_method = args.in_json is None
|
|
|
158 |
|
159 |
if args.seed is not None:
|
160 |
set_seed(args.seed)
|
@@ -162,40 +71,56 @@ def train(args):
|
|
162 |
tokenizer = train_util.load_tokenizer(args)
|
163 |
|
164 |
# データセットを準備する
|
165 |
-
if
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
print("Use DreamBooth method.")
|
172 |
-
train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
173 |
-
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
174 |
-
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
175 |
-
args.bucket_reso_steps, args.bucket_no_upscale,
|
176 |
-
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
177 |
-
args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
|
178 |
else:
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
-
train_util.debug_dataset(
|
194 |
return
|
195 |
-
if len(
|
196 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
197 |
return
|
198 |
|
|
|
|
|
|
|
|
|
199 |
# acceleratorを準備する
|
200 |
print("prepare accelerator")
|
201 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
@@ -205,9 +130,12 @@ def train(args):
|
|
205 |
|
206 |
# モデルを読み込む
|
207 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
211 |
# モデルに xformers とか memory efficient attention を組み込む
|
212 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
213 |
|
@@ -217,7 +145,7 @@ def train(args):
|
|
217 |
vae.requires_grad_(False)
|
218 |
vae.eval()
|
219 |
with torch.no_grad():
|
220 |
-
|
221 |
vae.to("cpu")
|
222 |
if torch.cuda.is_available():
|
223 |
torch.cuda.empty_cache()
|
@@ -253,165 +181,45 @@ def train(args):
|
|
253 |
|
254 |
# 学習に必要なクラスを準備する
|
255 |
print("prepare optimizer, data loader etc.")
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
except:
|
260 |
-
not_torch_optimizer_flag = True
|
261 |
-
try:
|
262 |
-
print(f"adastand version is {adastand.__version__()}")
|
263 |
-
not_adasatand_optimzier_flag = False
|
264 |
-
except:
|
265 |
-
not_adasatand_optimzier_flag = True
|
266 |
-
|
267 |
-
# 8-bit Adamを使う
|
268 |
-
if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
|
269 |
-
not_torch_optimizer_flag = False
|
270 |
-
if args.optimizer=="Adafactor":
|
271 |
-
not_adasatand_optimzier_flag = False
|
272 |
-
if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
|
273 |
-
print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
|
274 |
-
args.optimizer="AdamW"
|
275 |
-
if args.use_8bit_adam:
|
276 |
-
if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
|
277 |
-
print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
|
278 |
-
args.use_8bit_adam=False
|
279 |
-
if args.use_8bit_adam:
|
280 |
-
try:
|
281 |
-
import bitsandbytes as bnb
|
282 |
-
except ImportError:
|
283 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
284 |
-
print("use 8-bit Adam optimizer")
|
285 |
-
args.training_comment=f"{args.training_comment} use_8bit_adam=True"
|
286 |
-
if args.optimizer=="Lamb":
|
287 |
-
optimizer_class = bnb.optim.LAMB8bit
|
288 |
-
else:
|
289 |
-
args.optimizer="AdamW"
|
290 |
-
optimizer_class = bnb.optim.AdamW8bit
|
291 |
-
else:
|
292 |
-
print(f"use {args.optimizer}")
|
293 |
-
if args.optimizer=="RAdam":
|
294 |
-
optimizer_class = torch.optim.RAdam
|
295 |
-
elif args.optimizer=="AdaBound":
|
296 |
-
optimizer_class = optim.AdaBound
|
297 |
-
elif args.optimizer=="AdaBelief":
|
298 |
-
optimizer_class = optim.AdaBelief
|
299 |
-
elif args.optimizer=="AdamP":
|
300 |
-
optimizer_class = optim.AdamP
|
301 |
-
elif args.optimizer=="Adafactor":
|
302 |
-
optimizer_class = Adafactor
|
303 |
-
elif args.optimizer=="Adastand":
|
304 |
-
optimizer_class = adastand.Adastand
|
305 |
-
elif args.optimizer=="Adastand_belief":
|
306 |
-
optimizer_class = adastand.Adastand_b
|
307 |
-
elif args.optimizer=="AggMo":
|
308 |
-
optimizer_class = optim.AggMo
|
309 |
-
elif args.optimizer=="Apollo":
|
310 |
-
optimizer_class = optim.Apollo
|
311 |
-
elif args.optimizer=="Lamb":
|
312 |
-
optimizer_class = optim.Lamb
|
313 |
-
elif args.optimizer=="Ranger":
|
314 |
-
optimizer_class = optim.Ranger
|
315 |
-
elif args.optimizer=="RangerVA":
|
316 |
-
optimizer_class = optim.RangerVA
|
317 |
-
elif args.optimizer=="Yogi":
|
318 |
-
optimizer_class = optim.Yogi
|
319 |
-
elif args.optimizer=="Shampoo":
|
320 |
-
optimizer_class = optim.Shampoo
|
321 |
-
elif args.optimizer=="NovoGrad":
|
322 |
-
optimizer_class = optim.NovoGrad
|
323 |
-
elif args.optimizer=="QHAdam":
|
324 |
-
optimizer_class = optim.QHAdam
|
325 |
-
elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
|
326 |
-
optimizer_class = optim.DiffGrad
|
327 |
-
elif args.optimizer=="MADGRAD":
|
328 |
-
optimizer_class = optim.MADGRAD
|
329 |
-
else:
|
330 |
-
optimizer_class = torch.optim.AdamW
|
331 |
-
|
332 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
333 |
-
#optimizerデフォ設定
|
334 |
-
if args.optimizer_arg==None:
|
335 |
-
if args.optimizer=="AdaBelief":
|
336 |
-
args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
|
337 |
-
elif args.optimizer=="DiffGrad":
|
338 |
-
args.optimizer_arg = ["eps=1e-16"]
|
339 |
-
optimizer_arg = {}
|
340 |
-
lookahed_arg = {"k": 5, "alpha": 0.5}
|
341 |
-
adafactor_scheduler_arg = {"initial_lr": 0.}
|
342 |
-
int_args = ["k","n_sma_threshold","warmup"]
|
343 |
-
str_args = ["transformer","grad_transformer"]
|
344 |
-
if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
|
345 |
-
for _opt_arg in args.optimizer_arg:
|
346 |
-
key, value = _opt_arg.split("=")
|
347 |
-
if value=="True" or value=="False":
|
348 |
-
optimizer_arg[key]=bool((value=="True"))
|
349 |
-
elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
|
350 |
-
_value = value.split(",")
|
351 |
-
optimizer_arg[key] = (float(_value[0]),float(_value[1]))
|
352 |
-
del _value
|
353 |
-
elif key in int_args:
|
354 |
-
if "Lookahead" in args.optimizer:
|
355 |
-
lookahed_arg[key] = int(value)
|
356 |
-
else:
|
357 |
-
optimizer_arg[key] = int(value)
|
358 |
-
elif key in str_args:
|
359 |
-
optimizer_arg[key] = value
|
360 |
-
else:
|
361 |
-
if key=="alpha" and "Lookahead" in args.optimizer:
|
362 |
-
lookahed_arg[key] = int(value)
|
363 |
-
elif key=="initial_lr" and args.optimizer == "Adafactor":
|
364 |
-
adafactor_scheduler_arg[key] = float(value)
|
365 |
-
else:
|
366 |
-
optimizer_arg[key] = float(value)
|
367 |
-
del _opt_arg
|
368 |
-
AdafactorScheduler_Flag = False
|
369 |
-
list_of_init_lr = []
|
370 |
-
if args.optimizer=="Adafactor":
|
371 |
-
if not "relative_step" in optimizer_arg:
|
372 |
-
optimizer_arg["relative_step"] = True
|
373 |
-
if "warmup_init" in optimizer_arg:
|
374 |
-
if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
|
375 |
-
print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
|
376 |
-
optimizer_arg["relative_step"] = True
|
377 |
-
if optimizer_arg["relative_step"] == True:
|
378 |
-
AdafactorScheduler_Flag = True
|
379 |
-
list_of_init_lr = [0.,0.]
|
380 |
-
if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
|
381 |
-
if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
|
382 |
-
#if not "initial_lr" in adafactor_scheduler_arg:
|
383 |
-
# adafactor_scheduler_arg = args.learning_rate
|
384 |
-
args.learning_rate = None
|
385 |
-
args.text_encoder_lr = None
|
386 |
-
args.unet_lr = None
|
387 |
-
print(f"optimizer arg: {optimizer_arg}")
|
388 |
-
print("=-----------------------------------=")
|
389 |
-
if not AdafactorScheduler_Flag: args.split_lora_networks = False
|
390 |
if args.split_lora_networks:
|
|
|
391 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
392 |
append_module.replace_prepare_optimizer_params(network)
|
393 |
-
trainable_params,
|
394 |
else:
|
395 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
if args.
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
# dataloaderを準備する
|
411 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
412 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
413 |
train_dataloader = torch.utils.data.DataLoader(
|
414 |
-
|
415 |
|
416 |
# 学習ステップ数を計算する
|
417 |
if args.max_train_epochs is not None:
|
@@ -419,22 +227,18 @@ def train(args):
|
|
419 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
420 |
|
421 |
# lr schedulerを用意する
|
422 |
-
|
423 |
-
|
424 |
-
print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
|
425 |
-
lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
|
426 |
-
print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
|
427 |
-
del list_of_init_lr
|
428 |
else:
|
429 |
-
lr_scheduler = get_scheduler_fix(
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
#追加機能の設定をコメントに追記して残す
|
435 |
-
|
436 |
-
|
437 |
-
|
|
|
438 |
if args.min_resolution:
|
439 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
440 |
|
@@ -504,17 +308,19 @@ def train(args):
|
|
504 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
505 |
|
506 |
# 学習する
|
|
|
507 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
508 |
print("running training / 学習開始")
|
509 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
510 |
-
print(f" num reg images / 正則化画像の数: {
|
511 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
512 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
513 |
-
print(f" batch size per device / バッチサイズ: {
|
514 |
-
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
515 |
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
516 |
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
517 |
|
|
|
518 |
metadata = {
|
519 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
520 |
"ss_training_started_at": training_started_at, # unix timestamp
|
@@ -522,12 +328,10 @@ def train(args):
|
|
522 |
"ss_learning_rate": args.learning_rate,
|
523 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
524 |
"ss_unet_lr": args.unet_lr,
|
525 |
-
"ss_num_train_images":
|
526 |
-
"ss_num_reg_images":
|
527 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
528 |
"ss_num_epochs": num_train_epochs,
|
529 |
-
"ss_batch_size_per_device": args.train_batch_size,
|
530 |
-
"ss_total_batch_size": total_batch_size,
|
531 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
532 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
533 |
"ss_max_train_steps": args.max_train_steps,
|
@@ -539,28 +343,149 @@ def train(args):
|
|
539 |
"ss_mixed_precision": args.mixed_precision,
|
540 |
"ss_full_fp16": bool(args.full_fp16),
|
541 |
"ss_v2": bool(args.v2),
|
542 |
-
"ss_resolution": args.resolution,
|
543 |
"ss_clip_skip": args.clip_skip,
|
544 |
"ss_max_token_length": args.max_token_length,
|
545 |
-
"ss_color_aug": bool(args.color_aug),
|
546 |
-
"ss_flip_aug": bool(args.flip_aug),
|
547 |
-
"ss_random_crop": bool(args.random_crop),
|
548 |
-
"ss_shuffle_caption": bool(args.shuffle_caption),
|
549 |
"ss_cache_latents": bool(args.cache_latents),
|
550 |
-
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
551 |
-
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
552 |
-
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
553 |
"ss_seed": args.seed,
|
554 |
-
"
|
555 |
"ss_noise_offset": args.noise_offset,
|
556 |
-
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
557 |
-
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
558 |
-
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
559 |
-
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
560 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
561 |
-
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
}
|
563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
# uncomment if another network is added
|
565 |
# for key, value in net_kwargs.items():
|
566 |
# metadata["ss_arg_" + key] = value
|
@@ -596,7 +521,7 @@ def train(args):
|
|
596 |
loss_total = 0.0
|
597 |
for epoch in range(num_train_epochs):
|
598 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
599 |
-
|
600 |
|
601 |
metadata["ss_epoch"] = str(epoch+1)
|
602 |
|
@@ -633,7 +558,7 @@ def train(args):
|
|
633 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
634 |
|
635 |
# Predict the noise residual
|
636 |
-
with autocast():
|
637 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
638 |
|
639 |
if args.v_parameterization:
|
@@ -651,12 +576,18 @@ def train(args):
|
|
651 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
652 |
|
653 |
accelerator.backward(loss)
|
654 |
-
if accelerator.sync_gradients:
|
655 |
params_to_clip = network.get_trainable_params()
|
656 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
657 |
|
|
|
658 |
optimizer.step()
|
659 |
-
lr_scheduler.
|
|
|
|
|
|
|
|
|
|
|
660 |
optimizer.zero_grad(set_to_none=True)
|
661 |
|
662 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
@@ -664,6 +595,8 @@ def train(args):
|
|
664 |
progress_bar.update(1)
|
665 |
global_step += 1
|
666 |
|
|
|
|
|
667 |
current_loss = loss.detach().item()
|
668 |
if epoch == 0:
|
669 |
loss_list.append(current_loss)
|
@@ -676,7 +609,7 @@ def train(args):
|
|
676 |
progress_bar.set_postfix(**logs)
|
677 |
|
678 |
if args.logging_dir is not None:
|
679 |
-
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
680 |
accelerator.log(logs, step=global_step)
|
681 |
|
682 |
if global_step >= args.max_train_steps:
|
@@ -694,6 +627,7 @@ def train(args):
|
|
694 |
def save_func():
|
695 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
696 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
|
|
697 |
print(f"saving checkpoint: {ckpt_file}")
|
698 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
699 |
|
@@ -708,9 +642,12 @@ def train(args):
|
|
708 |
if saving and args.save_state:
|
709 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
710 |
|
|
|
|
|
711 |
# end of epoch
|
712 |
|
713 |
metadata["ss_epoch"] = str(num_train_epochs)
|
|
|
714 |
|
715 |
is_main_process = accelerator.is_main_process
|
716 |
if is_main_process:
|
@@ -741,6 +678,8 @@ if __name__ == '__main__':
|
|
741 |
train_util.add_sd_models_arguments(parser)
|
742 |
train_util.add_dataset_arguments(parser, True, True, True)
|
743 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
744 |
|
745 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
746 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
@@ -748,10 +687,6 @@ if __name__ == '__main__':
|
|
748 |
|
749 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
750 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
751 |
-
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
752 |
-
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
753 |
-
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
754 |
-
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
755 |
|
756 |
parser.add_argument("--network_weights", type=str, default=None,
|
757 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
@@ -771,27 +706,29 @@ if __name__ == '__main__':
|
|
771 |
#Optimizer変更関連のオプション追加
|
772 |
append_module.add_append_arguments(parser)
|
773 |
args = append_module.get_config(parser)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
774 |
|
775 |
if args.resolution==args.min_resolution:
|
776 |
args.min_resolution=None
|
777 |
|
778 |
train(args)
|
|
|
779 |
|
780 |
-
#学習が終わったら現在のargsを保存する
|
781 |
-
# import yaml
|
782 |
-
# import datetime
|
783 |
-
# _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
784 |
-
# if args.output_name==None:
|
785 |
-
# config_name = f"train_network_config_{_t}.yaml"
|
786 |
-
# else:
|
787 |
-
# config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
788 |
-
# print(f"{config_name} に設定を書き出し中...")
|
789 |
-
# with open(config_name, mode="w") as f:
|
790 |
-
# yaml.dump(args.__dict__, f, indent=4)
|
791 |
-
# print("done!")
|
792 |
|
793 |
'''
|
794 |
optimizer設定メモ
|
|
|
|
|
795 |
(optimizer_argから設定できるように変更するためのメモ)
|
796 |
|
797 |
AdamWのweight_decay初期値は1e-2
|
@@ -821,6 +758,7 @@ Adafactor
|
|
821 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
822 |
huggingfaceのサンプルパラ
|
823 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
|
|
824 |
|
825 |
AggMo
|
826 |
|
|
|
|
|
|
|
1 |
from torch.cuda.amp import autocast
|
2 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
3 |
import importlib
|
4 |
import argparse
|
5 |
import gc
|
|
|
14 |
from accelerate.utils import set_seed
|
15 |
import diffusers
|
16 |
from diffusers import DDPMScheduler
|
17 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
##### バケット拡張のためのモジュール
|
19 |
import append_module
|
20 |
######
|
21 |
import library.train_util as train_util
|
22 |
+
from library.train_util import (
|
23 |
+
DreamBoothDataset,
|
24 |
+
)
|
25 |
+
import library.config_util as config_util
|
26 |
+
from library.config_util import (
|
27 |
+
ConfigSanitizer,
|
28 |
+
BlueprintGenerator,
|
29 |
+
)
|
30 |
|
31 |
|
32 |
def collate_fn(examples):
|
33 |
return examples[0]
|
34 |
|
35 |
|
36 |
+
# TODO 他のスクリプトと共通化する
|
37 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, split_names=None):
|
38 |
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
39 |
+
if not args.split_lora_networks:
|
40 |
+
if args.network_train_unet_only:
|
41 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
42 |
+
elif args.network_train_text_encoder_only:
|
43 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
44 |
+
else:
|
45 |
+
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
46 |
+
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
47 |
else:
|
48 |
last_lrs = lr_scheduler.get_last_lr()
|
49 |
+
for last_lr, t_name in zip(last_lrs, split_names):
|
50 |
+
logs[f"lr/{t_name}"] = float(last_lr)
|
51 |
+
#D-Adaptationの仕様ちゃんと見てないからたぶん分割したのをちゃんと表示するならそれに合わせた記述が必要 でも多分D-Adaptationの挙動的に全部同一の形になるのでいらない
|
52 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
53 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
return logs
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def train(args):
|
59 |
session_id = random.randint(0, 2**32)
|
60 |
training_started_at = time.time()
|
|
|
63 |
|
64 |
cache_latents = args.cache_latents
|
65 |
use_dreambooth_method = args.in_json is None
|
66 |
+
use_user_config = args.dataset_config is not None
|
67 |
|
68 |
if args.seed is not None:
|
69 |
set_seed(args.seed)
|
|
|
71 |
tokenizer = train_util.load_tokenizer(args)
|
72 |
|
73 |
# データセットを準備する
|
74 |
+
if args.min_resolution:
|
75 |
+
args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
|
76 |
+
if len(args.min_resolution) == 1:
|
77 |
+
args.min_resolution = (args.min_resolution[0], args.min_resolution[0])
|
78 |
+
blueprint_generator = append_module.BlueprintGenerator(append_module.ConfigSanitizer(True, True, True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
else:
|
80 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
81 |
+
if use_user_config:
|
82 |
+
print(f"Load dataset config from {args.dataset_config}")
|
83 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
84 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
85 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
86 |
+
print(
|
87 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
88 |
+
else:
|
89 |
+
if use_dreambooth_method:
|
90 |
+
print("Use DreamBooth method.")
|
91 |
+
user_config = {
|
92 |
+
"datasets": [{
|
93 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
94 |
+
}]
|
95 |
+
}
|
96 |
+
else:
|
97 |
+
print("Train with captions.")
|
98 |
+
user_config = {
|
99 |
+
"datasets": [{
|
100 |
+
"subsets": [{
|
101 |
+
"image_dir": args.train_data_dir,
|
102 |
+
"metadata_file": args.in_json,
|
103 |
+
}]
|
104 |
+
}]
|
105 |
+
}
|
106 |
+
|
107 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
108 |
+
if args.min_resolution:
|
109 |
+
train_dataset_group = append_module.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
110 |
+
else:
|
111 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
112 |
|
113 |
if args.debug_dataset:
|
114 |
+
train_util.debug_dataset(train_dataset_group)
|
115 |
return
|
116 |
+
if len(train_dataset_group) == 0:
|
117 |
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
118 |
return
|
119 |
|
120 |
+
if cache_latents:
|
121 |
+
assert train_dataset_group.is_latent_cacheable(
|
122 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
123 |
+
|
124 |
# acceleratorを準備する
|
125 |
print("prepare accelerator")
|
126 |
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
|
|
130 |
|
131 |
# モデルを読み込む
|
132 |
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
133 |
+
|
134 |
+
# work on low-ram device
|
135 |
+
if args.lowram:
|
136 |
+
text_encoder.to("cuda")
|
137 |
+
unet.to("cuda")
|
138 |
+
|
139 |
# モデルに xformers とか memory efficient attention を組み込む
|
140 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
141 |
|
|
|
145 |
vae.requires_grad_(False)
|
146 |
vae.eval()
|
147 |
with torch.no_grad():
|
148 |
+
train_dataset_group.cache_latents(vae)
|
149 |
vae.to("cpu")
|
150 |
if torch.cuda.is_available():
|
151 |
torch.cuda.empty_cache()
|
|
|
181 |
|
182 |
# 学習に必要なクラスを準備する
|
183 |
print("prepare optimizer, data loader etc.")
|
184 |
+
split_flag = (args.split_lora_networks) or ((not args.network_train_text_encoder_only) and (not args.network_train_unet_only))
|
185 |
+
|
186 |
+
used_names = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
if args.split_lora_networks:
|
188 |
+
lr_dic, block_args_dic = append_module.create_lr_blocks(args.blocks_lr_setting, args.block_optim_args)
|
189 |
lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
|
190 |
append_module.replace_prepare_optimizer_params(network)
|
191 |
+
trainable_params, adafactor_scheduler_arg, used_names = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, lora_names, lr_dic, block_args_dic)
|
192 |
else:
|
193 |
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
194 |
+
if split_flag:
|
195 |
+
_t_lr = 0.
|
196 |
+
_u_lr = 0.
|
197 |
+
if args.text_encoder_lr:
|
198 |
+
_t_lr = args.text_encoder_lr
|
199 |
+
if args.unet_lr:
|
200 |
+
_u_lr = args.unet_lr
|
201 |
+
adafactor_scheduler_arg = {"initial_lr": [_t_lr, _u_lr]}
|
202 |
+
|
203 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
204 |
+
if args.use_lookahead:
|
205 |
+
try:
|
206 |
+
import torch_optimizer
|
207 |
+
lookahed_arg = {"k": 5, "alpha": 0.5}
|
208 |
+
if args.lookahead_arg is not None:
|
209 |
+
for _arg in args.lookahead_arg:
|
210 |
+
k, v = _arg.split("=")
|
211 |
+
if k == "k":
|
212 |
+
lookahed_arg[k] = int(v)
|
213 |
+
else:
|
214 |
+
lookahed_arg[k] = float(v)
|
215 |
+
optimizer = torch_optimizer.Lookahead(optimizer, **lookahed_arg)
|
216 |
+
except:
|
217 |
+
print("\n============\ntorch_optimizerのimportに失敗しました Lookaheadを無効化して処理を続けます\n============\n")
|
218 |
# dataloaderを準備する
|
219 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
220 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
221 |
train_dataloader = torch.utils.data.DataLoader(
|
222 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
223 |
|
224 |
# 学習ステップ数を計算する
|
225 |
if args.max_train_epochs is not None:
|
|
|
227 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
228 |
|
229 |
# lr schedulerを用意する
|
230 |
+
if args.lr_scheduler.startswith("adafactor") and split_flag:
|
231 |
+
lr_scheduler = append_module.get_scheduler_Adafactor(args.lr_scheduler, optimizer, adafactor_scheduler_arg)
|
|
|
|
|
|
|
|
|
232 |
else:
|
233 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
234 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
235 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
236 |
+
|
|
|
237 |
#追加機能の設定をコメントに追記して残す
|
238 |
+
if args.use_lookahead:
|
239 |
+
args.training_comment=f"{args.training_comment} use Lookahead: True Lookahead args: {lookahed_arg}"
|
240 |
+
if args.split_lora_networks:
|
241 |
+
args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks} split_level: {args.split_lora_level}"
|
242 |
if args.min_resolution:
|
243 |
args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"
|
244 |
|
|
|
308 |
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
309 |
|
310 |
# 学習する
|
311 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
312 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
313 |
print("running training / 学習開始")
|
314 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
315 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
316 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
317 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
318 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
319 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
320 |
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
321 |
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
322 |
|
323 |
+
# TODO refactor metadata creation and move to util
|
324 |
metadata = {
|
325 |
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
326 |
"ss_training_started_at": training_started_at, # unix timestamp
|
|
|
328 |
"ss_learning_rate": args.learning_rate,
|
329 |
"ss_text_encoder_lr": args.text_encoder_lr,
|
330 |
"ss_unet_lr": args.unet_lr,
|
331 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
332 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
333 |
"ss_num_batches_per_epoch": len(train_dataloader),
|
334 |
"ss_num_epochs": num_train_epochs,
|
|
|
|
|
335 |
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
336 |
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
337 |
"ss_max_train_steps": args.max_train_steps,
|
|
|
343 |
"ss_mixed_precision": args.mixed_precision,
|
344 |
"ss_full_fp16": bool(args.full_fp16),
|
345 |
"ss_v2": bool(args.v2),
|
|
|
346 |
"ss_clip_skip": args.clip_skip,
|
347 |
"ss_max_token_length": args.max_token_length,
|
|
|
|
|
|
|
|
|
348 |
"ss_cache_latents": bool(args.cache_latents),
|
|
|
|
|
|
|
349 |
"ss_seed": args.seed,
|
350 |
+
"ss_lowram": args.lowram,
|
351 |
"ss_noise_offset": args.noise_offset,
|
|
|
|
|
|
|
|
|
352 |
"ss_training_comment": args.training_comment, # will not be updated after training
|
353 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
354 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
355 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
356 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
357 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
358 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
359 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
360 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
361 |
}
|
362 |
|
363 |
+
if use_user_config:
|
364 |
+
# save metadata of multiple datasets
|
365 |
+
# NOTE: pack "ss_datasets" value as json one time
|
366 |
+
# or should also pack nested collections as json?
|
367 |
+
datasets_metadata = []
|
368 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
369 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
370 |
+
|
371 |
+
for dataset in train_dataset_group.datasets:
|
372 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
373 |
+
dataset_metadata = {
|
374 |
+
"is_dreambooth": is_dreambooth_dataset,
|
375 |
+
"batch_size_per_device": dataset.batch_size,
|
376 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
377 |
+
"num_reg_images": dataset.num_reg_images,
|
378 |
+
"resolution": (dataset.width, dataset.height),
|
379 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
380 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
381 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
382 |
+
"tag_frequency": dataset.tag_frequency,
|
383 |
+
"bucket_info": dataset.bucket_info,
|
384 |
+
}
|
385 |
+
|
386 |
+
subsets_metadata = []
|
387 |
+
for subset in dataset.subsets:
|
388 |
+
subset_metadata = {
|
389 |
+
"img_count": subset.img_count,
|
390 |
+
"num_repeats": subset.num_repeats,
|
391 |
+
"color_aug": bool(subset.color_aug),
|
392 |
+
"flip_aug": bool(subset.flip_aug),
|
393 |
+
"random_crop": bool(subset.random_crop),
|
394 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
395 |
+
"keep_tokens": subset.keep_tokens,
|
396 |
+
}
|
397 |
+
|
398 |
+
image_dir_or_metadata_file = None
|
399 |
+
if subset.image_dir:
|
400 |
+
image_dir = os.path.basename(subset.image_dir)
|
401 |
+
subset_metadata["image_dir"] = image_dir
|
402 |
+
image_dir_or_metadata_file = image_dir
|
403 |
+
|
404 |
+
if is_dreambooth_dataset:
|
405 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
406 |
+
subset_metadata["is_reg"] = subset.is_reg
|
407 |
+
if subset.is_reg:
|
408 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
409 |
+
else:
|
410 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
411 |
+
subset_metadata["metadata_file"] = metadata_file
|
412 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
413 |
+
|
414 |
+
subsets_metadata.append(subset_metadata)
|
415 |
+
|
416 |
+
# merge dataset dir: not reg subset only
|
417 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
418 |
+
if image_dir_or_metadata_file is not None:
|
419 |
+
# datasets may have a certain dir multiple times
|
420 |
+
v = image_dir_or_metadata_file
|
421 |
+
i = 2
|
422 |
+
while v in dataset_dirs_info:
|
423 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
424 |
+
i += 1
|
425 |
+
image_dir_or_metadata_file = v
|
426 |
+
|
427 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
428 |
+
"n_repeats": subset.num_repeats,
|
429 |
+
"img_count": subset.img_count
|
430 |
+
}
|
431 |
+
|
432 |
+
dataset_metadata["subsets"] = subsets_metadata
|
433 |
+
datasets_metadata.append(dataset_metadata)
|
434 |
+
|
435 |
+
# merge tag frequency:
|
436 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
437 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
438 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
439 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
440 |
+
if ds_dir_name in tag_frequency:
|
441 |
+
continue
|
442 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
443 |
+
|
444 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
445 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
446 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
447 |
+
else:
|
448 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
449 |
+
assert len(
|
450 |
+
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
451 |
+
|
452 |
+
dataset = train_dataset_group.datasets[0]
|
453 |
+
|
454 |
+
dataset_dirs_info = {}
|
455 |
+
reg_dataset_dirs_info = {}
|
456 |
+
if use_dreambooth_method:
|
457 |
+
for subset in dataset.subsets:
|
458 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
459 |
+
info[os.path.basename(subset.image_dir)] = {
|
460 |
+
"n_repeats": subset.num_repeats,
|
461 |
+
"img_count": subset.img_count
|
462 |
+
}
|
463 |
+
else:
|
464 |
+
for subset in dataset.subsets:
|
465 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
466 |
+
"n_repeats": subset.num_repeats,
|
467 |
+
"img_count": subset.img_count
|
468 |
+
}
|
469 |
+
|
470 |
+
metadata.update({
|
471 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
472 |
+
"ss_total_batch_size": total_batch_size,
|
473 |
+
"ss_resolution": args.resolution,
|
474 |
+
"ss_color_aug": bool(args.color_aug),
|
475 |
+
"ss_flip_aug": bool(args.flip_aug),
|
476 |
+
"ss_random_crop": bool(args.random_crop),
|
477 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
478 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
479 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
480 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
481 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
482 |
+
"ss_keep_tokens": args.keep_tokens,
|
483 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
484 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
485 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
486 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
487 |
+
})
|
488 |
+
|
489 |
# uncomment if another network is added
|
490 |
# for key, value in net_kwargs.items():
|
491 |
# metadata["ss_arg_" + key] = value
|
|
|
521 |
loss_total = 0.0
|
522 |
for epoch in range(num_train_epochs):
|
523 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
524 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
525 |
|
526 |
metadata["ss_epoch"] = str(epoch+1)
|
527 |
|
|
|
558 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
559 |
|
560 |
# Predict the noise residual
|
561 |
+
with accelerator.autocast():
|
562 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
563 |
|
564 |
if args.v_parameterization:
|
|
|
576 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
577 |
|
578 |
accelerator.backward(loss)
|
579 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
580 |
params_to_clip = network.get_trainable_params()
|
581 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
582 |
|
583 |
+
scale = accelerator.scaler.get_scale()
|
584 |
optimizer.step()
|
585 |
+
if args.lr_scheduler.startswith("adafactor"):
|
586 |
+
skip_lr_sched = (scale >= accelerator.scaler.get_scale())
|
587 |
+
else:
|
588 |
+
skip_lr_sched = True
|
589 |
+
if not skip_lr_sched:
|
590 |
+
lr_scheduler.step()
|
591 |
optimizer.zero_grad(set_to_none=True)
|
592 |
|
593 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
|
595 |
progress_bar.update(1)
|
596 |
global_step += 1
|
597 |
|
598 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
599 |
+
|
600 |
current_loss = loss.detach().item()
|
601 |
if epoch == 0:
|
602 |
loss_list.append(current_loss)
|
|
|
609 |
progress_bar.set_postfix(**logs)
|
610 |
|
611 |
if args.logging_dir is not None:
|
612 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, used_names)
|
613 |
accelerator.log(logs, step=global_step)
|
614 |
|
615 |
if global_step >= args.max_train_steps:
|
|
|
627 |
def save_func():
|
628 |
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
629 |
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
630 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
631 |
print(f"saving checkpoint: {ckpt_file}")
|
632 |
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
633 |
|
|
|
642 |
if saving and args.save_state:
|
643 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
644 |
|
645 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
646 |
+
|
647 |
# end of epoch
|
648 |
|
649 |
metadata["ss_epoch"] = str(num_train_epochs)
|
650 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
651 |
|
652 |
is_main_process = accelerator.is_main_process
|
653 |
if is_main_process:
|
|
|
678 |
train_util.add_sd_models_arguments(parser)
|
679 |
train_util.add_dataset_arguments(parser, True, True, True)
|
680 |
train_util.add_training_arguments(parser, True)
|
681 |
+
train_util.add_optimizer_arguments(parser)
|
682 |
+
config_util.add_config_arguments(parser)
|
683 |
|
684 |
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
685 |
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
|
|
687 |
|
688 |
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
689 |
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
|
|
|
|
|
|
|
|
690 |
|
691 |
parser.add_argument("--network_weights", type=str, default=None,
|
692 |
help="pretrained weights for network / 学習するネットワークの初期重み")
|
|
|
706 |
#Optimizer変更関連のオプション追加
|
707 |
append_module.add_append_arguments(parser)
|
708 |
args = append_module.get_config(parser)
|
709 |
+
#argsを保存する
|
710 |
+
import yaml
|
711 |
+
import datetime
|
712 |
+
_t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
|
713 |
+
if args.output_name==None:
|
714 |
+
config_name = f"train_network_config_{_t}.yaml"
|
715 |
+
else:
|
716 |
+
config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
|
717 |
+
print(f"{config_name} に設定を書き出し中...")
|
718 |
+
with open(config_name, mode="w") as f:
|
719 |
+
yaml.dump(args.__dict__, f, indent=4)
|
720 |
|
721 |
if args.resolution==args.min_resolution:
|
722 |
args.min_resolution=None
|
723 |
|
724 |
train(args)
|
725 |
+
print("done!")
|
726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
|
728 |
'''
|
729 |
optimizer設定メモ
|
730 |
+
torch_optimizer.AdaBelief
|
731 |
+
adastand.Adastand
|
732 |
(optimizer_argから設定できるように変更するためのメモ)
|
733 |
|
734 |
AdamWのweight_decay初期値は1e-2
|
|
|
758 |
transformerベースのT5学習において最強とかいう噂のoptimizer
|
759 |
huggingfaceのサンプルパラ
|
760 |
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False
|
761 |
+
epsの二つ目の値1e-3が学習率に影響大きい
|
762 |
|
763 |
AggMo
|
764 |
|
train_textual_inversion.py
CHANGED
@@ -11,7 +11,11 @@ import diffusers
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
|
16 |
imagenet_templates_small = [
|
17 |
"a photo of a {}",
|
@@ -79,7 +83,6 @@ def train(args):
|
|
79 |
train_util.prepare_dataset_args(args, True)
|
80 |
|
81 |
cache_latents = args.cache_latents
|
82 |
-
use_dreambooth_method = args.in_json is None
|
83 |
|
84 |
if args.seed is not None:
|
85 |
set_seed(args.seed)
|
@@ -139,21 +142,35 @@ def train(args):
|
|
139 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
140 |
|
141 |
# データセットを準備する
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
159 |
if use_template:
|
@@ -163,20 +180,25 @@ def train(args):
|
|
163 |
captions = []
|
164 |
for tmpl in templates:
|
165 |
captions.append(tmpl.format(replace_to))
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
172 |
|
173 |
if args.debug_dataset:
|
174 |
-
train_util.debug_dataset(
|
175 |
return
|
176 |
-
if len(
|
177 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
178 |
return
|
179 |
|
|
|
|
|
|
|
180 |
# モデルに xformers とか memory efficient attention を組み込む
|
181 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
182 |
|
@@ -186,7 +208,7 @@ def train(args):
|
|
186 |
vae.requires_grad_(False)
|
187 |
vae.eval()
|
188 |
with torch.no_grad():
|
189 |
-
|
190 |
vae.to("cpu")
|
191 |
if torch.cuda.is_available():
|
192 |
torch.cuda.empty_cache()
|
@@ -198,35 +220,14 @@ def train(args):
|
|
198 |
|
199 |
# 学習に必要なクラスを準備する
|
200 |
print("prepare optimizer, data loader etc.")
|
201 |
-
|
202 |
-
# 8-bit Adamを使う
|
203 |
-
if args.use_8bit_adam:
|
204 |
-
try:
|
205 |
-
import bitsandbytes as bnb
|
206 |
-
except ImportError:
|
207 |
-
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
208 |
-
print("use 8-bit Adam optimizer")
|
209 |
-
optimizer_class = bnb.optim.AdamW8bit
|
210 |
-
elif args.use_lion_optimizer:
|
211 |
-
try:
|
212 |
-
import lion_pytorch
|
213 |
-
except ImportError:
|
214 |
-
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
215 |
-
print("use Lion optimizer")
|
216 |
-
optimizer_class = lion_pytorch.Lion
|
217 |
-
else:
|
218 |
-
optimizer_class = torch.optim.AdamW
|
219 |
-
|
220 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
221 |
-
|
222 |
-
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
223 |
-
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
224 |
|
225 |
# dataloaderを準備する
|
226 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
227 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
228 |
train_dataloader = torch.utils.data.DataLoader(
|
229 |
-
|
230 |
|
231 |
# 学習ステップ数を計算する
|
232 |
if args.max_train_epochs is not None:
|
@@ -234,8 +235,9 @@ def train(args):
|
|
234 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
235 |
|
236 |
# lr schedulerを用意する
|
237 |
-
lr_scheduler =
|
238 |
-
|
|
|
239 |
|
240 |
# acceleratorがなんかよろしくやってくれるらしい
|
241 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
@@ -283,8 +285,8 @@ def train(args):
|
|
283 |
# 学習する
|
284 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
285 |
print("running training / 学習開始")
|
286 |
-
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {
|
287 |
-
print(f" num reg images / 正則化画像の数: {
|
288 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
289 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
290 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
@@ -303,12 +305,11 @@ def train(args):
|
|
303 |
|
304 |
for epoch in range(num_train_epochs):
|
305 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
306 |
-
|
307 |
|
308 |
text_encoder.train()
|
309 |
|
310 |
loss_total = 0
|
311 |
-
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
312 |
for step, batch in enumerate(train_dataloader):
|
313 |
with accelerator.accumulate(text_encoder):
|
314 |
with torch.no_grad():
|
@@ -357,9 +358,9 @@ def train(args):
|
|
357 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
358 |
|
359 |
accelerator.backward(loss)
|
360 |
-
if accelerator.sync_gradients:
|
361 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
362 |
-
accelerator.clip_grad_norm_(params_to_clip,
|
363 |
|
364 |
optimizer.step()
|
365 |
lr_scheduler.step()
|
@@ -374,9 +375,14 @@ def train(args):
|
|
374 |
progress_bar.update(1)
|
375 |
global_step += 1
|
376 |
|
|
|
|
|
|
|
377 |
current_loss = loss.detach().item()
|
378 |
if args.logging_dir is not None:
|
379 |
-
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
380 |
accelerator.log(logs, step=global_step)
|
381 |
|
382 |
loss_total += current_loss
|
@@ -394,8 +400,6 @@ def train(args):
|
|
394 |
accelerator.wait_for_everyone()
|
395 |
|
396 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
397 |
-
# d = updated_embs - bef_epo_embs
|
398 |
-
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
399 |
|
400 |
if args.save_every_n_epochs is not None:
|
401 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
@@ -417,6 +421,9 @@ def train(args):
|
|
417 |
if saving and args.save_state:
|
418 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
419 |
|
|
|
|
|
|
|
420 |
# end of epoch
|
421 |
|
422 |
is_main_process = accelerator.is_main_process
|
@@ -491,6 +498,8 @@ if __name__ == '__main__':
|
|
491 |
train_util.add_sd_models_arguments(parser)
|
492 |
train_util.add_dataset_arguments(parser, True, True, False)
|
493 |
train_util.add_training_arguments(parser, True)
|
|
|
|
|
494 |
|
495 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
496 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
|
|
11 |
from diffusers import DDPMScheduler
|
12 |
|
13 |
import library.train_util as train_util
|
14 |
+
import library.config_util as config_util
|
15 |
+
from library.config_util import (
|
16 |
+
ConfigSanitizer,
|
17 |
+
BlueprintGenerator,
|
18 |
+
)
|
19 |
|
20 |
imagenet_templates_small = [
|
21 |
"a photo of a {}",
|
|
|
83 |
train_util.prepare_dataset_args(args, True)
|
84 |
|
85 |
cache_latents = args.cache_latents
|
|
|
86 |
|
87 |
if args.seed is not None:
|
88 |
set_seed(args.seed)
|
|
|
142 |
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
143 |
|
144 |
# データセットを準備する
|
145 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
146 |
+
if args.dataset_config is not None:
|
147 |
+
print(f"Load dataset config from {args.dataset_config}")
|
148 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
149 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
150 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
151 |
+
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
152 |
else:
|
153 |
+
use_dreambooth_method = args.in_json is None
|
154 |
+
if use_dreambooth_method:
|
155 |
+
print("Use DreamBooth method.")
|
156 |
+
user_config = {
|
157 |
+
"datasets": [{
|
158 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
159 |
+
}]
|
160 |
+
}
|
161 |
+
else:
|
162 |
+
print("Train with captions.")
|
163 |
+
user_config = {
|
164 |
+
"datasets": [{
|
165 |
+
"subsets": [{
|
166 |
+
"image_dir": args.train_data_dir,
|
167 |
+
"metadata_file": args.in_json,
|
168 |
+
}]
|
169 |
+
}]
|
170 |
+
}
|
171 |
+
|
172 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
173 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
174 |
|
175 |
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
176 |
if use_template:
|
|
|
180 |
captions = []
|
181 |
for tmpl in templates:
|
182 |
captions.append(tmpl.format(replace_to))
|
183 |
+
train_dataset_group.add_replacement("", captions)
|
184 |
+
else:
|
185 |
+
if args.num_vectors_per_token > 1:
|
186 |
+
replace_to = " ".join(token_strings)
|
187 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
188 |
+
prompt_replacement = (args.token_string, replace_to)
|
189 |
+
else:
|
190 |
+
prompt_replacement = None
|
191 |
|
192 |
if args.debug_dataset:
|
193 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
194 |
return
|
195 |
+
if len(train_dataset_group) == 0:
|
196 |
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
197 |
return
|
198 |
|
199 |
+
if cache_latents:
|
200 |
+
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
201 |
+
|
202 |
# モデルに xformers とか memory efficient attention を組み込む
|
203 |
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
204 |
|
|
|
208 |
vae.requires_grad_(False)
|
209 |
vae.eval()
|
210 |
with torch.no_grad():
|
211 |
+
train_dataset_group.cache_latents(vae)
|
212 |
vae.to("cpu")
|
213 |
if torch.cuda.is_available():
|
214 |
torch.cuda.empty_cache()
|
|
|
220 |
|
221 |
# 学習に必要なクラスを準備する
|
222 |
print("prepare optimizer, data loader etc.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
trainable_params = text_encoder.get_input_embeddings().parameters()
|
224 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
|
|
|
|
225 |
|
226 |
# dataloaderを準備する
|
227 |
# DataLoaderのプロセス数:0はメインプロセスになる
|
228 |
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
229 |
train_dataloader = torch.utils.data.DataLoader(
|
230 |
+
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
231 |
|
232 |
# 学習ステップ数を計算する
|
233 |
if args.max_train_epochs is not None:
|
|
|
235 |
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
236 |
|
237 |
# lr schedulerを用意する
|
238 |
+
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
239 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
240 |
+
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
241 |
|
242 |
# acceleratorがなんかよろしくやってくれるらしい
|
243 |
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
285 |
# 学習する
|
286 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
287 |
print("running training / 学習開始")
|
288 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
289 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
290 |
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
291 |
print(f" num epochs / epoch数: {num_train_epochs}")
|
292 |
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
|
|
305 |
|
306 |
for epoch in range(num_train_epochs):
|
307 |
print(f"epoch {epoch+1}/{num_train_epochs}")
|
308 |
+
train_dataset_group.set_current_epoch(epoch + 1)
|
309 |
|
310 |
text_encoder.train()
|
311 |
|
312 |
loss_total = 0
|
|
|
313 |
for step, batch in enumerate(train_dataloader):
|
314 |
with accelerator.accumulate(text_encoder):
|
315 |
with torch.no_grad():
|
|
|
358 |
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
359 |
|
360 |
accelerator.backward(loss)
|
361 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
362 |
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
363 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
364 |
|
365 |
optimizer.step()
|
366 |
lr_scheduler.step()
|
|
|
375 |
progress_bar.update(1)
|
376 |
global_step += 1
|
377 |
|
378 |
+
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
379 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
380 |
+
|
381 |
current_loss = loss.detach().item()
|
382 |
if args.logging_dir is not None:
|
383 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
384 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
385 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
386 |
accelerator.log(logs, step=global_step)
|
387 |
|
388 |
loss_total += current_loss
|
|
|
400 |
accelerator.wait_for_everyone()
|
401 |
|
402 |
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
|
|
|
403 |
|
404 |
if args.save_every_n_epochs is not None:
|
405 |
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
|
|
421 |
if saving and args.save_state:
|
422 |
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
423 |
|
424 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
425 |
+
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
426 |
+
|
427 |
# end of epoch
|
428 |
|
429 |
is_main_process = accelerator.is_main_process
|
|
|
498 |
train_util.add_sd_models_arguments(parser)
|
499 |
train_util.add_dataset_arguments(parser, True, True, False)
|
500 |
train_util.add_training_arguments(parser, True)
|
501 |
+
train_util.add_optimizer_arguments(parser)
|
502 |
+
config_util.add_config_arguments(parser)
|
503 |
|
504 |
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
505 |
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|