abc
committed on
Commit
·
b9a80b5
1
Parent(s):
bd600ff
Delete train_network_opt.py
Browse files- train_network_opt.py +0 -832
train_network_opt.py
DELETED
@@ -1,832 +0,0 @@
|
|
1 |
-
# --- imports and optional optimizer-library probing ---------------------
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
import gc
import math
import os
import random
import time
import json

from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
print("**********************************")
# NOTE: the extra optimizers below require `pip install torch_optimizer`
# to have been run beforehand.
try:
    import torch_optimizer as optim
except ImportError:
    # was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit are not swallowed
    print("torch_optimizerがインストールされていないためAdafactorとAdastand以外の追加optimzierは使えません。\noptimizerの変更をしたい場合先にpip install torch_optimizerでライブラリを追加してください")
try:
    import adastand
except ImportError:
    # Adastand optimizers are unavailable without the adastand package
    print("※Adastandが使えません")

from transformers.optimization import Adafactor, AdafactorSchedule
print("**********************************")
##### module implementing the extended bucketing support
import append_module
######
import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset
|
40 |
-
|
41 |
-
|
42 |
-
def collate_fn(examples):
    """Identity collate: the dataset already yields a fully prepared batch,
    so the DataLoader (batch_size=1) just needs the element unwrapped."""
    first, *_ = examples
    return first
|
44 |
-
|
45 |
-
|
46 |
-
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
    """Build the per-step logging dict: current/average loss plus one
    ``lr/<name>`` entry per optimizer param group, named according to how the
    network was split (2, 4, 8 or per-layer groups)."""
    logs = {"loss/current": current_loss, "loss/average": avr_loss}

    # single-module training: only one lr to report
    if args.network_train_unet_only:
        logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
        return logs
    if args.network_train_text_encoder_only:
        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
        return logs

    last_lrs = lr_scheduler.get_last_lr()
    if len(last_lrs) == 2:
        logs["lr/textencoder"] = float(last_lrs[0])
        logs["lr/unet"] = float(last_lrs[-1])  # may be same to textencoder
        return logs

    # interleaved down/up block names shared by the split-network layouts
    unet_block_names = [name
                        for i in range(3)
                        for name in (f"unet_down_blocks_{i}", f"unet_up_blocks_{i+1}")]
    if len(last_lrs) == 4:
        logs_names = ["textencoder", "lora_unet_mid_block", "unet_down_blocks", "unet_up_blocks"]
    elif len(last_lrs) == 8:
        logs_names = ["textencoder", "unet_midblock"] + unet_block_names
    else:
        # finest split: one group per text-encoder layer, then the unet groups
        logs_names = [f"text_model_encoder_layers_{i}_" for i in range(12)]
        logs_names.append("unet_midblock")
        logs_names.extend(unet_block_names)

    # zip truncates to the shorter list, tolerating unexpected group counts
    logs.update((f"lr/{name}", float(lr)) for lr, name in zip(last_lrs, logs_names))
    return logs
|
79 |
-
|
80 |
-
|
81 |
-
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
82 |
-
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
83 |
-
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
84 |
-
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
85 |
-
|
86 |
-
|
87 |
-
def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: float = 1.,
    power: float = 1.0,
):
    """Unified API to get any scheduler from its name.

    Backported from a newer diffusers release (see module comment): builds the
    scheduler registered for ``name``, validating that ``num_warmup_steps`` /
    ``num_training_steps`` are supplied when the chosen schedule needs them.
    ``num_cycles`` applies to COSINE / COSINE_WITH_RESTARTS (int-cast for the
    latter) and ``power`` to POLYNOMIAL.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # constant schedule ignores both step counts entirely
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # every remaining schedule needs a warmup length
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # every remaining schedule also needs the total step count
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    # accumulate schedule-specific kwargs, then build once at the end
    kwargs = {"num_warmup_steps": num_warmup_steps, "num_training_steps": num_training_steps}
    if name == SchedulerType.COSINE:
        print(f"{name} num_cycles: {num_cycles}")
        kwargs["num_cycles"] = num_cycles
    elif name == SchedulerType.COSINE_WITH_RESTARTS:
        print(f"{name} num_cycles: {int(num_cycles)}")
        kwargs["num_cycles"] = int(num_cycles)
    elif name == SchedulerType.POLYNOMIAL:
        kwargs["power"] = power
    return schedule_func(optimizer, **kwargs)
|
148 |
-
|
149 |
-
|
150 |
-
def train(args):
    """Run one full network (e.g. LoRA) training session described by ``args``.

    High-level flow: build the dataset, load the SD models, attach the
    trainable network, select an optimizer (many optional libraries) and lr
    scheduler, then run the epoch/step loop, logging and saving checkpoints.
    """
    # random id grouping all epochs saved from this run (stored in metadata)
    session_id = random.randint(0, 2**32)
    training_started_at = time.time()
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    cache_latents = args.cache_latents
    # no caption json -> DreamBooth-style folder dataset
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
    if use_dreambooth_method:
        if args.min_resolution:
            # "W,H" string -> (W, H) tuple; a single value means a square
            args.min_resolution = tuple([int(r) for r in args.min_resolution.split(',')])
            if len(args.min_resolution) == 1:
                args.min_resolution = (args.min_resolution[0], args.min_resolution[0])

        print("Use DreamBooth method.")
        # NOTE: uses the extended dataset from append_module (bucket extensions)
        train_dataset = append_module.DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
                    tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
                    args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                    args.bucket_reso_steps, args.bucket_no_upscale,
                    args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
                    args.random_crop, args.debug_dataset, args.min_resolution, args.area_step)
    else:
        print("Train with captions.")
        train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
                    tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
                    args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                    args.bucket_reso_steps, args.bucket_no_upscale,
                    args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
                    args.dataset_repeats, args.debug_dataset)

    # configure caption dropout for the training data
    train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

    train_dataset.make_buckets()

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset)
        return
    if len(train_dataset) == 0:
        print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
        return

    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # dtypes matching the requested mixed precision; values are cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the target models
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
    # unnecessary, but work on low-ram device
    text_encoder.to("cuda")
    unet.to("cuda")
    # incorporate xformers or memory-efficient attention into the U-Net
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare for training
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset.cache_latents(vae)
        # the VAE is no longer needed once latents are cached
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # prepare network
    print("import network module:", args.network_module)
    network_module = importlib.import_module(args.network_module)

    net_kwargs = {}
    if args.network_args is not None:
        # each entry is a "key=value" string
        for net_arg in args.network_args:
            key, value = net_arg.split('=')
            net_kwargs[key] = value

    # if a new network is added in future, add if ~ then blocks for each network (;'∀')
    network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
    if network is None:
        return

    if args.network_weights is not None:
        print("load network weights from:", args.network_weights)
        network.load_weights(args.network_weights)

    train_unet = not args.network_train_text_encoder_only
    train_text_encoder = not args.network_train_unet_only
    network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()
        network.enable_gradient_checkpointing()  # may have no effect

    # prepare classes needed for training (optimizer, data loader, ...)
    print("prepare optimizer, data loader etc.")
    # probe whether the optional optimizer libraries imported successfully at
    # module top; a NameError here means the library is unavailable
    try:
        print(f"torch_optimzier version is {optim.__version__}")
        not_torch_optimizer_flag = False
    except:
        not_torch_optimizer_flag = True
    try:
        print(f"adastand version is {adastand.__version__()}")
        not_adasatand_optimzier_flag = False
    except:
        not_adasatand_optimzier_flag = True

    # use 8-bit Adam
    # Adafactor/Adastand variants do not require torch_optimizer
    if args.optimizer=="Adafactor" or args.optimizer=="Adastand" or args.optimizer=="Adastand_belief":
        not_torch_optimizer_flag = False
        if args.optimizer=="Adafactor":
            not_adasatand_optimzier_flag = False
    if not_torch_optimizer_flag or not_adasatand_optimzier_flag:
        # fall back to AdamW when the requested optimizer's library is missing
        print(f"==========================\n必要なライブラリがないため {args.optimizer} の使用ができません。optimizerを AdamW に変更して実行します\n==========================")
        args.optimizer="AdamW"
    if args.use_8bit_adam:
        # 8-bit variants only exist for AdamW and Lamb
        if not args.optimizer=="AdamW" and not args.optimizer=="Lamb":
            print(f"\n==========================\n{args.optimizer} は8bitAdamに実装されていないので8bitAdamをオフにします\n==========================\n")
            args.use_8bit_adam=False
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
        print("use 8-bit Adam optimizer")
        args.training_comment=f"{args.training_comment} use_8bit_adam=True"
        if args.optimizer=="Lamb":
            optimizer_class = bnb.optim.LAMB8bit
        else:
            args.optimizer="AdamW"
            optimizer_class = bnb.optim.AdamW8bit
    else:
        print(f"use {args.optimizer}")
        # dispatch the optimizer name to its implementing class
        if args.optimizer=="RAdam":
            optimizer_class = torch.optim.RAdam
        elif args.optimizer=="AdaBound":
            optimizer_class = optim.AdaBound
        elif args.optimizer=="AdaBelief":
            optimizer_class = optim.AdaBelief
        elif args.optimizer=="AdamP":
            optimizer_class = optim.AdamP
        elif args.optimizer=="Adafactor":
            optimizer_class = Adafactor
        elif args.optimizer=="Adastand":
            optimizer_class = adastand.Adastand
        elif args.optimizer=="Adastand_belief":
            optimizer_class = adastand.Adastand_b
        elif args.optimizer=="AggMo":
            optimizer_class = optim.AggMo
        elif args.optimizer=="Apollo":
            optimizer_class = optim.Apollo
        elif args.optimizer=="Lamb":
            optimizer_class = optim.Lamb
        elif args.optimizer=="Ranger":
            optimizer_class = optim.Ranger
        elif args.optimizer=="RangerVA":
            optimizer_class = optim.RangerVA
        elif args.optimizer=="Yogi":
            optimizer_class = optim.Yogi
        elif args.optimizer=="Shampoo":
            optimizer_class = optim.Shampoo
        elif args.optimizer=="NovoGrad":
            optimizer_class = optim.NovoGrad
        elif args.optimizer=="QHAdam":
            optimizer_class = optim.QHAdam
        elif args.optimizer=="DiffGrad" or args.optimizer=="Lookahead_DiffGrad":
            optimizer_class = optim.DiffGrad
        elif args.optimizer=="MADGRAD":
            optimizer_class = optim.MADGRAD
        else:
            optimizer_class = torch.optim.AdamW

    # betas / weight decay are left at their defaults (matching diffusers
    # DreamBooth and DreamBooth SD), so no options for them here
    # default optimizer arguments per optimizer
    if args.optimizer_arg==None:
        if args.optimizer=="AdaBelief":
            args.optimizer_arg = ["eps=1e-16","betas=0.9,0.999","weight_decouple=True","rectify=False","fixed_decay=False"]
        elif args.optimizer=="DiffGrad":
            args.optimizer_arg = ["eps=1e-16"]
    optimizer_arg = {}
    lookahed_arg = {"k": 5, "alpha": 0.5}
    adafactor_scheduler_arg = {"initial_lr": 0.}
    int_args = ["k","n_sma_threshold","warmup"]    # keys parsed as int
    str_args = ["transformer","grad_transformer"]  # keys kept as str
    if not args.optimizer_arg==None and len(args.optimizer_arg)>0:
        # parse each "key=value" string into a typed optimizer kwarg
        for _opt_arg in args.optimizer_arg:
            key, value = _opt_arg.split("=")
            if value=="True" or value=="False":
                optimizer_arg[key]=bool((value=="True"))
            elif key=="betas" or key=="nus" or key=="eps2" or (key=="eps" and "," in value):
                # comma-separated pair -> 2-tuple of floats
                _value = value.split(",")
                optimizer_arg[key] = (float(_value[0]),float(_value[1]))
                del _value
            elif key in int_args:
                # Lookahead wrapper consumes k itself
                if "Lookahead" in args.optimizer:
                    lookahed_arg[key] = int(value)
                else:
                    optimizer_arg[key] = int(value)
            elif key in str_args:
                optimizer_arg[key] = value
            else:
                if key=="alpha" and "Lookahead" in args.optimizer:
                    lookahed_arg[key] = int(value)
                elif key=="initial_lr" and args.optimizer == "Adafactor":
                    adafactor_scheduler_arg[key] = float(value)
                else:
                    optimizer_arg[key] = float(value)
        del _opt_arg
    AdafactorScheduler_Flag = False
    list_of_init_lr = []
    if args.optimizer=="Adafactor":
        # Adafactor defaults to relative_step=True (lr computed internally)
        if not "relative_step" in optimizer_arg:
            optimizer_arg["relative_step"] = True
        if "warmup_init" in optimizer_arg:
            # warmup_init requires relative_step; force it on if needed
            if optimizer_arg["warmup_init"]==True and optimizer_arg["relative_step"]==False:
                print("**************\nwarmup_initはrelative_stepがオンである必要があるためrelative_stepをオンにします\n**************")
                optimizer_arg["relative_step"] = True
        if optimizer_arg["relative_step"] == True:
            # with relative_step the external lrs are ignored; keep them only
            # as initial lrs for the AdafactorSchedule wrapper
            AdafactorScheduler_Flag = True
            list_of_init_lr = [0.,0.]
            if args.text_encoder_lr is not None: list_of_init_lr[0] = float(args.text_encoder_lr)
            if args.unet_lr is not None: list_of_init_lr[1] = float(args.unet_lr)
            #if not "initial_lr" in adafactor_scheduler_arg:
            #  adafactor_scheduler_arg = args.learning_rate
            args.learning_rate = None
            args.text_encoder_lr = None
            args.unet_lr = None
    print(f"optimizer arg: {optimizer_arg}")
    print("=-----------------------------------=")
    # split-lora mode only makes sense together with the Adafactor schedule
    if not AdafactorScheduler_Flag: args.split_lora_networks = False
    if args.split_lora_networks:
        lora_names = append_module.create_split_names(args.split_lora_networks, args.split_lora_level)
        append_module.replace_prepare_optimizer_params(network)
        trainable_params, _list_of_init_lr = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, list_of_init_lr, lora_names)
    else:
        trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
        _list_of_init_lr = []
    print(f"trainable_params_len: {len(trainable_params)}")
    if len(_list_of_init_lr)>0:
        list_of_init_lr = _list_of_init_lr
        print(f"split loras network is {len(list_of_init_lr)}")
    if len(list_of_init_lr) > 0:
        adafactor_scheduler_arg["initial_lr"] = list_of_init_lr

    optimizer = optimizer_class(trainable_params, lr=args.learning_rate, **optimizer_arg)

    # optionally wrap the base optimizer with Lookahead
    if args.optimizer=="Lookahead_DiffGrad" or args.optimizer=="Lookahedad_Adam":
        optimizer = optim.Lookahead(optimizer, **lookahed_arg)
        print(f"lookahed_arg: {lookahed_arg}")

    # prepare the dataloader
    # number of DataLoader worker processes: 0 means the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, capped at the configured maximum
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

    # compute the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # prepare the lr scheduler
    # lr_scheduler = diffusers.optimization.get_scheduler(
    if AdafactorScheduler_Flag:
        print("===================================\nAdafactorはデフォルトでrelative_stepがオンになっているので lrは自動算出されるためLrScheculerの指定も無効になります\nもし任意のLrやLr_Schedulerを使いたい場合は --optimizer_arg relative_ste=False を指定してください\nまた任意のLrを使う場合は scale_parameter=False も併せて指定するのが推奨です\n===================================")
        lr_scheduler = append_module.AdafactorSchedule_append(optimizer, **adafactor_scheduler_arg)
        print(f"AdafactorSchedule initial lrs: {lr_scheduler.get_lr()}")
        del list_of_init_lr
    else:
        lr_scheduler = get_scheduler_fix(
            args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

    # append the extra settings to the training comment so they are preserved
    args.training_comment=f"{args.training_comment} optimizer: {args.optimizer} / optimizer_arg: {args.optimizer_arg}"
    if AdafactorScheduler_Flag:
        args.training_comment=f"{args.training_comment} split_lora_networks: {args.split_lora_networks}"
    if args.min_resolution:
        args.training_comment=f"{args.training_comment} min_resolution: {args.min_resolution} area_step: {args.area_step}"

    # experimental feature: fp16 training including gradients
    # makes the whole model fp16
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        network.to(weight_dtype)

    # let accelerate wrap only the modules that are actually trained
    if train_unet and train_text_encoder:
        unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler)
    elif train_unet:
        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, network, optimizer, train_dataloader, lr_scheduler)
    elif train_text_encoder:
        text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            text_encoder, network, optimizer, train_dataloader, lr_scheduler)
    else:
        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            network, optimizer, train_dataloader, lr_scheduler)

    # the base models are frozen; only the attached network is trained
    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.to(accelerator.device)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
        text_encoder.train()

        # set top parameter requires_grad = True for gradient checkpointing works
        if type(text_encoder) == DDP:
            text_encoder.module.text_model.embeddings.requires_grad_(True)
        else:
            text_encoder.text_model.embeddings.requires_grad_(True)
    else:
        unet.eval()
        text_encoder.eval()

    # support DistributedDataParallel
    if type(text_encoder) == DDP:
        text_encoder = text_encoder.module
        unet = unet.module
        network = network.module

    network.prepare_grad_etc(text_encoder, unet)

    if not cache_latents:
        # the VAE is still needed every step to encode images to latents
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # experimental feature: patch PyTorch to enable grad scaling for full fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume from a saved training state
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # start training
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
    print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f" num epochs / epoch数: {num_train_epochs}")
    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    # metadata embedded in the saved model files
    metadata = {
        "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
        "ss_training_started_at": training_started_at,  # unix timestamp
        "ss_output_name": args.output_name,
        "ss_learning_rate": args.learning_rate,
        "ss_text_encoder_lr": args.text_encoder_lr,
        "ss_unet_lr": args.unet_lr,
        "ss_num_train_images": train_dataset.num_train_images,  # includes repeating
        "ss_num_reg_images": train_dataset.num_reg_images,
        "ss_num_batches_per_epoch": len(train_dataloader),
        "ss_num_epochs": num_train_epochs,
        "ss_batch_size_per_device": args.train_batch_size,
        "ss_total_batch_size": total_batch_size,
        "ss_gradient_checkpointing": args.gradient_checkpointing,
        "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
        "ss_max_train_steps": args.max_train_steps,
        "ss_lr_warmup_steps": args.lr_warmup_steps,
        "ss_lr_scheduler": args.lr_scheduler,
        "ss_network_module": args.network_module,
        "ss_network_dim": args.network_dim,  # None means default because another network than LoRA may have another default dim
        "ss_network_alpha": args.network_alpha,  # some networks may not use this value
        "ss_mixed_precision": args.mixed_precision,
        "ss_full_fp16": bool(args.full_fp16),
        "ss_v2": bool(args.v2),
        "ss_resolution": args.resolution,
        "ss_clip_skip": args.clip_skip,
        "ss_max_token_length": args.max_token_length,
        "ss_color_aug": bool(args.color_aug),
        "ss_flip_aug": bool(args.flip_aug),
        "ss_random_crop": bool(args.random_crop),
        "ss_shuffle_caption": bool(args.shuffle_caption),
        "ss_cache_latents": bool(args.cache_latents),
        "ss_enable_bucket": bool(train_dataset.enable_bucket),
        "ss_min_bucket_reso": train_dataset.min_bucket_reso,
        "ss_max_bucket_reso": train_dataset.max_bucket_reso,
        "ss_seed": args.seed,
        "ss_keep_tokens": args.keep_tokens,
        "ss_noise_offset": args.noise_offset,
        "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
        "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
        "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
        "ss_bucket_info": json.dumps(train_dataset.bucket_info),
        "ss_training_comment": args.training_comment,  # will not be updated after training
        "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
    }

    # uncomment if another network is added
    # for key, value in net_kwargs.items():
    #   metadata["ss_arg_" + key] = value

    if args.pretrained_model_name_or_path is not None:
        sd_model_name = args.pretrained_model_name_or_path
        if os.path.exists(sd_model_name):
            # record both legacy and sha256 hashes of the base model
            metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
            metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
            sd_model_name = os.path.basename(sd_model_name)
        metadata["ss_sd_model_name"] = sd_model_name

    if args.vae is not None:
        vae_name = args.vae
        if os.path.exists(vae_name):
            metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
            metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
            vae_name = os.path.basename(vae_name)
        metadata["ss_vae_name"] = vae_name

    # metadata values are serialized as strings
    metadata = {k: str(v) for k, v in metadata.items()}

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                    num_train_timesteps=1000, clip_sample=False)

    if accelerator.is_main_process:
        accelerator.init_trackers("network_train")

    # sliding window of per-step losses, one slot per step of an epoch
    loss_list = []
    loss_total = 0.0
    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        train_dataset.set_current_epoch(epoch + 1)

        metadata["ss_epoch"] = str(epoch+1)

        network.on_epoch_start(text_encoder, unet)

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(network):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # encode images to latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * 0.18215
                    b_size = latents.shape[0]

                with torch.set_grad_enabled(train_text_encoder):
                    # Get the text embedding for conditioning
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # per-sample weight
                loss = loss * loss_weights

                loss = loss.mean()  # already a mean, so no division by batch_size needed

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = network.get_trainable_params()
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            # moving-average loss: fill the window during epoch 0, then
            # replace the slot for this step in place
            current_loss = loss.detach().item()
            if epoch == 0:
                loss_list.append(current_loss)
            else:
                loss_total -= loss_list[step]
                loss_list[step] = current_loss
            loss_total += current_loss
            avr_loss = loss_total / len(loss_list)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if args.logging_dir is not None:
                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
                accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(loss_list)}
            accelerator.log(logs, step=epoch+1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                # save this epoch's network weights (with metadata unless disabled)
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)

            def remove_old_func(old_epoch_no):
                # delete a superseded epoch checkpoint, if it exists
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        # end of epoch

    metadata["ss_epoch"] = str(num_train_epochs)

    is_main_process = accelerator.is_main_process
    if is_main_process:
        network = unwrap_model(network)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # memory is needed afterwards, so free the accelerator here

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        # final save of the trained network weights
        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
        print("model saved.")
|
736 |
-
|
737 |
-
|
738 |
-
if __name__ == '__main__':
    # Script entry point: build the CLI from the shared sd-scripts argument
    # groups, add the options specific to LoRA-network training, then run
    # train(args).
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, True)

    parser.add_argument("--no_metadata", action='store_true',
                        help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
    parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
                        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")

    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
    parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
    parser.add_argument("--lr_scheduler_power", type=float, default=1,
                        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

    parser.add_argument("--network_weights", type=str, default=None,
                        help="pretrained weights for network / 学習するネットワークの初期重み")
    parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
    parser.add_argument("--network_dim", type=int, default=None,
                        help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
    parser.add_argument("--network_alpha", type=float, default=1,
                        help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
    # NOTE: fixed typo in the English half of the help string ("argmuments").
    parser.add_argument("--network_args", type=str, default=None, nargs='*',
                        help='additional arguments for network (key=value) / ネットワークへの追加の引数')
    parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
    parser.add_argument("--network_train_text_encoder_only", action="store_true",
                        help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
    parser.add_argument("--training_comment", type=str, default=None,
                        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")

    # Options related to swapping the optimizer are added by the bucket/optimizer
    # extension module; get_config also merges any config-file values into args.
    append_module.add_append_arguments(parser)
    args = append_module.get_config(parser)

    # When the minimum bucket resolution equals the training resolution the
    # min-resolution handling is a no-op, so drop it entirely.
    if args.resolution == args.min_resolution:
        args.min_resolution = None

    train(args)

    # (disabled) After training, dump the effective args to a timestamped YAML
    # file so the run can be reproduced later:
    # import yaml
    # import datetime
    # _t = datetime.datetime.today().strftime('%Y%m%d_%H%M')
    # if args.output_name is None:
    #     config_name = f"train_network_config_{_t}.yaml"
    # else:
    #     config_name = f"train_network_config_{os.path.basename(args.output_name)}_{_t}.yaml"
    # print(f"{config_name} に設定を書き出し中...")
    # with open(config_name, mode="w") as f:
    #     yaml.dump(args.__dict__, f, indent=4)
    # print("done!")

'''
optimizer設定メモ
(optimizer_argから設定できるように変更するためのメモ)

AdamWのweight_decay初期値は1e-2

RAdam
weight_decay=1e-2 (? 初期値は0だけど上手くいかないのでweight_decayを設定したほうがいいのかもしれない)

AdaBelief
eps=1e-16 betas=0.9,0.999 weight_decouple=True rectify=False fiexed_decay=False
論文中でTransformerの成績が良かった設定
weight_decay=1e-4 weight_decouple=True rectify=True fixed_decay=False

DiffGrad
eps=1e-16
MADGRAD
eps=1e-6

NovoGrad
論文ではAdamWと比較して汎化性能が物凄く良くなってるらしい
再現実装の比較によるとAdamより性能は高いけど収束に必要なstepが多い
実験paramater: weight_decay=1e-2 amsgrad=True

QHAdam
収束早い

Adafactor
transformerベースのT5学習において最強とかいう噂のoptimizer
huggingfaceのサンプルパラ
eps=1e-30,1e-3 clip_threshold=1.0 decay_rate=-0.8 relative_step=False scale_parameter=False warmup_init=False

AggMo

Apollo(学習内容によってやたらとVRAM消費が増え続けるので非推奨)

未整頓のパラメータメモ
Ranger(収束できる設定が見つかる気がしない)
RangerVA(収束できる設定が見つかる気がしない)
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|