import gc
import json
import time
from functools import partial
from typing import Union
import os
import tkinter as tk
from tkinter import filedialog as fd, ttk
from tkinter import simpledialog as sd
from tkinter import messagebox as mb

import torch.cuda
import train_network
import library.train_util as util
import argparse

class ArgStore:
    # Holds the default training configuration. Every attribute becomes a key in the dict
    # returned by convert_args_to_dict(), which the prompts below then fill in.

    def __init__(self):
        # paths, naming, and json config handling
        self.base_model: str = r""
        self.img_folder: str = r""
        self.output_folder: str = r""
        self.change_output_name: Union[str, None] = None
        self.save_json_folder: Union[str, None] = None
        self.load_json_path: Union[str, None] = None
        self.json_load_skip_list: Union[list[str], None] = ["save_json_folder", "reg_img_folder",
                                                            "lora_model_for_resume", "change_output_name",
                                                            "training_comment",
                                                            "json_load_skip_list"]

        # caption dropout and noise offset
        self.caption_dropout_rate: Union[float, None] = None
        self.caption_dropout_every_n_epochs: Union[int, None] = None
        self.caption_tag_dropout_rate: Union[float, None] = None
        self.noise_offset: Union[float, None] = None

        # network size
        self.net_dim: int = 128
        self.alpha: float = 128

        # learning rate and scheduler settings
        self.scheduler: str = "cosine_with_restarts"
        self.cosine_restarts: Union[int, None] = 1
        self.scheduler_power: Union[float, None] = 1
        self.warmup_lr_ratio: Union[float, None] = None
        self.learning_rate: Union[float, None] = 1e-4
        self.text_encoder_lr: Union[float, None] = None
        self.unet_lr: Union[float, None] = None
        self.num_workers: int = 1
        self.persistent_workers: bool = True

        # general training settings
        self.batch_size: int = 1
        self.num_epochs: int = 1
        self.save_every_n_epochs: Union[int, None] = None
        self.shuffle_captions: bool = False
        self.keep_tokens: Union[int, None] = None
        self.max_steps: Union[int, None] = None
        self.tag_occurrence_txt_file: bool = False
        self.sort_tag_occurrence_alphabetically: bool = False

        # resolution, bucketing, resuming, and which parts of the network to train
        self.train_resolution: int = 512
        self.min_bucket_resolution: int = 320
        self.max_bucket_resolution: int = 960
        self.lora_model_for_resume: Union[str, None] = None
        self.save_state: bool = False
        self.load_previous_save_state: Union[str, None] = None
        self.training_comment: Union[str, None] = None
        self.unet_only: bool = False
        self.text_only: bool = False

        # remaining training options
        self.reg_img_folder: Union[str, None] = None
        self.clip_skip: int = 2
        self.test_seed: int = 23
        self.prior_loss_weight: float = 1
        self.gradient_checkpointing: bool = False
        self.gradient_acc_steps: Union[int, None] = None
        self.mixed_precision: str = "fp16"
        self.save_precision: str = "fp16"
        self.save_as: str = "safetensors"
        self.caption_extension: str = ".txt"
        self.max_clip_token_length = 150
        self.buckets: bool = True
        self.xformers: bool = True
        self.use_8bit_adam: bool = True
        self.cache_latents: bool = True
        self.color_aug: bool = False
        self.flip_aug: bool = False
        self.vae: Union[str, None] = None
        self.no_meta: bool = False
        self.log_dir: Union[str, None] = None
        self.v2: bool = False
        self.v_parameterization: bool = False

    @staticmethod
    def convert_args_to_dict():
        return ArgStore().__dict__


def main():
    parser = argparse.ArgumentParser()
    setup_args(parser)
    pre_args = parser.parse_args()
    queues = 0
    args_queue = []
    cont = True
    while cont:
        arg_dict = ArgStore.convert_args_to_dict()
        ret = mb.askyesno(message="Do you want to load a json config file?")
        if ret:
            load_json(ask_file("select json to load from", {"json"}), arg_dict)
            arg_dict = ask_elements_trunc(arg_dict)
        else:
            arg_dict = ask_elements(arg_dict)
        if pre_args.save_json_path or arg_dict["save_json_folder"]:
            save_json(pre_args.save_json_path if pre_args.save_json_path else arg_dict['save_json_folder'], arg_dict)
        args = create_arg_space(arg_dict)
        args = parser.parse_args(args)
        queues += 1
        args_queue.append(args)
        if arg_dict['tag_occurrence_txt_file']:
            get_occurrence_of_tags(arg_dict)
        ret = mb.askyesno(message="Do you want to queue another training?")
        if not ret:
            cont = False
    for args in args_queue:
        try:
            train_network.train(args)
        except Exception as e:
            print(f"Failed to train this set of args.\nSkipping this training session.\nError is: {e}")
        gc.collect()
        torch.cuda.empty_cache()


def create_arg_space(args: dict) -> list[str]:
    output = ["--network_module=networks.lora", f"--pretrained_model_name_or_path={args['base_model']}",
              f"--train_data_dir={args['img_folder']}", f"--output_dir={args['output_folder']}",
              f"--prior_loss_weight={args['prior_loss_weight']}", f"--caption_extension={args['caption_extension']}",
              f"--resolution={args['train_resolution']}", f"--train_batch_size={args['batch_size']}",
              f"--mixed_precision={args['mixed_precision']}", f"--save_precision={args['save_precision']}",
              f"--network_dim={args['net_dim']}", f"--save_model_as={args['save_as']}",
              f"--clip_skip={args['clip_skip']}", f"--seed={args['test_seed']}",
              f"--max_token_length={args['max_clip_token_length']}", f"--lr_scheduler={args['scheduler']}",
              f"--network_alpha={args['alpha']}", f"--max_data_loader_n_workers={args['num_workers']}"]
    if not args['max_steps']:
        output.append(f"--max_train_epochs={args['num_epochs']}")
        output += create_optional_args(args, find_max_steps(args))
    else:
        output.append(f"--max_train_steps={args['max_steps']}")
        output += create_optional_args(args, args['max_steps'])
    return output
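# Illustrative example (not a real run): with the default ArgStore values this builds an argv list like
#   ["--network_module=networks.lora", "--pretrained_model_name_or_path=<base_model>",
#    "--train_data_dir=<img_folder>", ..., "--max_train_epochs=1", "--enable_bucket", ...]
# which main() then feeds back through parser.parse_args() before calling train_network.train().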


def create_optional_args(args: dict, steps):
    output = []
    if args["reg_img_folder"]:
        output.append(f"--reg_data_dir={args['reg_img_folder']}")

    if args['lora_model_for_resume']:
        output.append(f"--network_weights={args['lora_model_for_resume']}")

    if args['save_every_n_epochs']:
        output.append(f"--save_every_n_epochs={args['save_every_n_epochs']}")
    else:
        output.append("--save_every_n_epochs=999999")

    if args['shuffle_captions']:
        output.append("--shuffle_caption")

    if args['keep_tokens'] and args['keep_tokens'] > 0:
        output.append(f"--keep_tokens={args['keep_tokens']}")

    if args['buckets']:
        output.append("--enable_bucket")
        output.append(f"--min_bucket_reso={args['min_bucket_resolution']}")
        output.append(f"--max_bucket_reso={args['max_bucket_resolution']}")

    if args['use_8bit_adam']:
        output.append("--use_8bit_adam")

    if args['xformers']:
        output.append("--xformers")

    if args['color_aug']:
        if args['cache_latents']:
            print("color_aug and cache_latents conflict with one another. Please select only one")
            quit(1)
        output.append("--color_aug")

    if args['flip_aug']:
        output.append("--flip_aug")

    if args['cache_latents']:
        output.append("--cache_latents")

    if args['warmup_lr_ratio'] and args['warmup_lr_ratio'] > 0:
        warmup_steps = int(steps * args['warmup_lr_ratio'])
        output.append(f"--lr_warmup_steps={warmup_steps}")

    if args['gradient_checkpointing']:
        output.append("--gradient_checkpointing")

    if args['gradient_acc_steps'] and args['gradient_acc_steps'] > 0 and args['gradient_checkpointing']:
        output.append(f"--gradient_accumulation_steps={args['gradient_acc_steps']}")

    if args['learning_rate'] and args['learning_rate'] > 0:
        output.append(f"--learning_rate={args['learning_rate']}")

    if args['text_encoder_lr'] and args['text_encoder_lr'] > 0:
        output.append(f"--text_encoder_lr={args['text_encoder_lr']}")

    if args['unet_lr'] and args['unet_lr'] > 0:
        output.append(f"--unet_lr={args['unet_lr']}")

    if args['vae']:
        output.append(f"--vae={args['vae']}")

    if args['no_meta']:
        output.append("--no_metadata")

    if args['save_state']:
        output.append("--save_state")

    if args['load_previous_save_state']:
        output.append(f"--resume={args['load_previous_save_state']}")

    if args['change_output_name']:
        output.append(f"--output_name={args['change_output_name']}")

    if args['training_comment']:
        output.append(f"--training_comment={args['training_comment']}")

    if args['cosine_restarts'] and args['scheduler'] == "cosine_with_restarts":
        output.append(f"--lr_scheduler_num_cycles={args['cosine_restarts']}")

    if args['scheduler_power'] and args['scheduler'] == "polynomial":
        output.append(f"--lr_scheduler_power={args['scheduler_power']}")

    if args['persistent_workers']:
        output.append("--persistent_data_loader_workers")

    if args['unet_only']:
        output.append("--network_train_unet_only")

    if args['text_only'] and not args['unet_only']:
        output.append("--network_train_text_encoder_only")

    if args["log_dir"]:
        output.append(f"--logging_dir={args['log_dir']}")

    if args['caption_dropout_rate']:
        output.append(f"--caption_dropout_rate={args['caption_dropout_rate']}")

    if args['caption_dropout_every_n_epochs']:
        output.append(f"--caption_dropout_every_n_epochs={args['caption_dropout_every_n_epochs']}")

    if args['caption_tag_dropout_rate']:
        output.append(f"--caption_tag_dropout_rate={args['caption_tag_dropout_rate']}")

    if args['v2']:
        output.append("--v2")

    if args['v2'] and args['v_parameterization']:
        output.append("--v_parameterization")

    if args['noise_offset']:
        output.append(f"--noise_offset={args['noise_offset']}")
    return output


def find_max_steps(args: dict) -> int:
    total_steps = 0
    folders = os.listdir(args["img_folder"])
    for folder in folders:
        if not os.path.isdir(os.path.join(args["img_folder"], folder)):
            continue
        num_repeats = folder.split("_")
        if len(num_repeats) < 2:
            print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
            continue
        try:
            num_repeats = int(num_repeats[0])
        except ValueError:
            print(f"folder {folder} is not in the correct format. Format is x_name. skipping")
            continue
        imgs = 0
        for file in os.listdir(os.path.join(args["img_folder"], folder)):
            # check the full path, not the bare filename, so nested folders are skipped correctly
            if os.path.isdir(os.path.join(args["img_folder"], folder, file)):
                continue
            ext = file.split(".")
            if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
                imgs += 1
        total_steps += (num_repeats * imgs)
    total_steps = int((total_steps / args["batch_size"]) * args["num_epochs"])
    return total_steps
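# Example of the folder layout find_max_steps expects (illustrative names):
#   img_folder/
#       5_concept/  -> 40 images  => 5 * 40  = 200 steps
#       2_style/    -> 100 images => 2 * 100 = 200 steps
# total_steps = int((400 / batch_size) * num_epochs)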


def add_misc_args(parser):
    parser.add_argument("--save_json_path", type=str, default=None,
                        help="Path to save a configuration json file to")
    parser.add_argument("--load_json_path", type=str, default=None,
                        help="Path to a json file to configure things from")
    parser.add_argument("--no_metadata", action='store_true',
                        help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
    parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
                        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")

    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument("--text_encoder_lr", type=float, default=None,
                        help="learning rate for Text Encoder / Text Encoderの学習率")
    parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
    parser.add_argument("--lr_scheduler_power", type=float, default=1,
                        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

    parser.add_argument("--network_weights", type=str, default=None,
                        help="pretrained weights for network / 学習するネットワークの初期重み")
    parser.add_argument("--network_module", type=str, default=None,
                        help='network module to train / 学習対象のネットワークのモジュール')
    parser.add_argument("--network_dim", type=int, default=None,
                        help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
    parser.add_argument("--network_alpha", type=float, default=1,
                        help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
    parser.add_argument("--network_args", type=str, default=None, nargs='*',
                        help='additional arguments for network (key=value) / ネットワークへの追加の引数')
    parser.add_argument("--network_train_unet_only", action="store_true",
                        help="only training U-Net part / U-Net関連部分のみ学習する")
    parser.add_argument("--network_train_text_encoder_only", action="store_true",
                        help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
    parser.add_argument("--training_comment", type=str, default=None,
                        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")


def setup_args(parser):
    util.add_sd_models_arguments(parser)
    util.add_dataset_arguments(parser, True, True, True)
    util.add_training_arguments(parser, True)
    add_misc_args(parser)


def get_occurrence_of_tags(args):
    extension = args['caption_extension']
    img_folder = args['img_folder']
    output_folder = args['output_folder']
    occurrence_dict = {}
    print(img_folder)
    for folder in os.listdir(img_folder):
        print(folder)
        if not os.path.isdir(os.path.join(img_folder, folder)):
            continue
        for file in os.listdir(os.path.join(img_folder, folder)):
            if not os.path.isfile(os.path.join(img_folder, folder, file)):
                continue
            ext = os.path.splitext(file)[1]
            if ext != extension:
                continue
            get_tags_from_file(os.path.join(img_folder, folder, file), occurrence_dict)
    if not args['sort_tag_occurrence_alphabetically']:
        output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[1], reverse=True)}
    else:
        output_list = {k: v for k, v in sorted(occurrence_dict.items(), key=lambda item: item[0])}
    name = args['change_output_name'] if args['change_output_name'] else "last"
    with open(os.path.join(output_folder, f"{name}.txt"), "w") as f:
        f.write(f"Below is a list of keywords used during the training of {name}:\n")
        for k, v in output_list.items():
            f.write(f"[{v}] {k}\n")
    print(f"Created a txt file named {name}.txt in the output folder")


def get_tags_from_file(file, occurrence_dict):
    # caption files are comma-separated tag lists; strip the trailing newline so the last tag is counted correctly
    with open(file) as f:
        temp = f.read().strip().replace(", ", ",").split(",")
    for tag in temp:
        if tag in occurrence_dict:
            occurrence_dict[tag] += 1
        else:
            occurrence_dict[tag] = 1
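# Example (illustrative): a caption file containing "1girl, solo, red hair" adds
# {"1girl": 1, "solo": 1, "red hair": 1} to occurrence_dict.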


def ask_file(message, accepted_ext_list, file_path=None):
    mb.showinfo(message=message)
    res = ""
    _initialdir = ""
    _initialfile = ""
    if file_path is not None:
        _initialdir = os.path.dirname(file_path) if os.path.exists(file_path) else ""
        _initialfile = os.path.basename(file_path) if os.path.exists(file_path) else ""

    while res == "":
        res = fd.askopenfilename(title=message, initialdir=_initialdir, initialfile=_initialfile)
        if res == "" or isinstance(res, tuple):
            ret = mb.askretrycancel(message="Do you want to cancel training?")
            if not ret:
                exit()
            continue
        elif not os.path.exists(res):
            res = ""
            continue
        _, name = os.path.split(res)
        split_name = name.split(".")
        if split_name[-1] not in accepted_ext_list:
            res = ""
    return res


def ask_dir(message, dir_path=None):
    mb.showinfo(message=message)
    res = ""
    _initialdir = ""
    if dir_path is not None:
        _initialdir = dir_path if os.path.exists(dir_path) else ""
    while res == "":
        res = fd.askdirectory(title=message, initialdir=_initialdir)
        if res == "" or isinstance(res, tuple):
            ret = mb.askretrycancel(message="Do you want to cancel training?")
            if not ret:
                exit()
            continue
        if not os.path.exists(res):
            res = ""
    return res


def ask_elements_trunc(args: dict):
    args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
    args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
    args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])

    ret = mb.askyesno(message="Do you want to save a json of your configuration?")
    if ret:
        args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
    else:
        args['save_json_folder'] = None

    ret = mb.askyesno(message="Are you training on an SD2-based model?")
    if ret:
        args['v2'] = True

    ret = mb.askyesno(message="Are you training on a realistic model?")
    if ret:
        args['clip_skip'] = 1

    if args['v2']:
        ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
        if ret:
            args['v_parameterization'] = True

    ret = mb.askyesno(message="Do you want to use regularization images?")
    if ret:
        args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
    else:
        args['reg_img_folder'] = None

    ret = mb.askyesno(message="Do you want to continue from an earlier version?")
    if ret:
        args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
                                                 args['lora_model_for_resume'])
    else:
        args['lora_model_for_resume'] = None

    ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
                              "within your dataset, but it can also ruin learning an asymmetrical element\n")
    if ret:
        args['flip_aug'] = True

    ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
    if ret:
        ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
                                                       "Cancel keeps the original name")
        if ret:
            args['change_output_name'] = ret
        else:
            args['change_output_name'] = None

    ret = sd.askstring(title="comment",
                       prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
                              "be to include how to use, such as activation keywords.\nCancel will leave it empty")
    # only keep the comment when the user actually entered one
    args['training_comment'] = ret if ret else None

    ret = mb.askyesno(message="Do you want to train only one of the unet and text encoder?")
    if ret:
        button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
        button.window.mainloop()
        if button.current_value != "":
            args[button.current_value] = True

    ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
                              "of all tags that you have used in your training data?\n")
    if ret:
        args['tag_occurrence_txt_file'] = True
        button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
        button.window.mainloop()
        if button.current_value == "alphabetically":
            args['sort_tag_occurrence_alphabetically'] = True

    ret = mb.askyesno(message="Do you want to use caption dropout?")
    if ret:
        ret = mb.askyesno(message="Do you want full caption files to drop out randomly?")
        if ret:
            ret = sd.askinteger(title="Caption_File_Dropout",
                                prompt="How often do you want caption files to drop out?\n"
                                       "Enter a number from 0 to 100 that is the percentage chance of dropout\n"
                                       "Cancel sets it to 0")
            if ret and 0 <= ret <= 100:
                args['caption_dropout_rate'] = ret / 100.0

        ret = mb.askyesno(message="Do you want entire epochs to have no captions?")
        if ret:
            ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often an epoch "
                                                                      "will have no captions.\nSo if you set 3, every "
                                                                      "third epoch will not have captions (3, 6, 9)\n"
                                                                      "Cancel will set it to None")
            if ret:
                args['caption_dropout_every_n_epochs'] = ret

        ret = mb.askyesno(message="Do you want tags to randomly drop out?")
        if ret:
            ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
                                                                    "Enter a number between 0 and 100, that is the "
                                                                    "percentage chance of dropout.\nCancel sets it to 0")
            if ret and 0 <= ret <= 100:
                args['caption_tag_dropout_rate'] = ret / 100.0

    ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow SD to generate\n"
                              "darker or lighter images better than it normally would.")
    if ret:
        ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? The recommended value is 0.1,\n"
                                                       "but it can go higher. Cancel defaults to 0.1")
        if ret:
            args['noise_offset'] = ret
        else:
            args['noise_offset'] = 0.1
    return args


def ask_elements(args: dict):
    # start with the required paths
    args['base_model'] = ask_file("Select your base model", {"ckpt", "safetensors"}, args['base_model'])
    args['img_folder'] = ask_dir("Select your image folder", args['img_folder'])
    args['output_folder'] = ask_dir("Select your output folder", args['output_folder'])

    ret = mb.askyesno(message="Do you want to save a json of your configuration?")
    if ret:
        args['save_json_folder'] = ask_dir("Select the folder to save json files to", args['save_json_folder'])
    else:
        args['save_json_folder'] = None

    ret = mb.askyesno(message="Are you training on an SD2-based model?")
    if ret:
        args['v2'] = True

    ret = mb.askyesno(message="Are you training on a realistic model?")
    if ret:
        args['clip_skip'] = 1

    if args['v2']:
        ret = mb.askyesno(message="Are you training on a model based on the 768x version of SD2?")
        if ret:
            args['v_parameterization'] = True

    ret = mb.askyesno(message="Do you want to use regularization images?")
    if ret:
        args['reg_img_folder'] = ask_dir("Select your regularization folder", args['reg_img_folder'])
    else:
        args['reg_img_folder'] = None

    ret = mb.askyesno(message="Do you want to continue from an earlier version?")
    if ret:
        args['lora_model_for_resume'] = ask_file("Select your lora model", {"ckpt", "pt", "safetensors"},
                                                 args['lora_model_for_resume'])
    else:
        args['lora_model_for_resume'] = None

    ret = mb.askyesno(message="Do you want to flip all of your images? It is supposed to reduce biases\n"
                              "within your dataset, but it can also ruin learning an asymmetrical element\n")
    if ret:
        args['flip_aug'] = True

    ret = sd.askinteger(title="batch_size",
                        prompt="The number of images that get processed at one time; this is directly proportional to "
                               "your VRAM and resolution. With 12GB of VRAM at 512 resolution, you can get a maximum "
                               "batch size of about 6\nHow large is your batch size going to be?\nCancel will default to 1")
    if ret is None:
        args['batch_size'] = 1
    else:
        args['batch_size'] = ret

    ret = sd.askinteger(title="num_epochs", prompt="How many epochs do you want?\nCancel will default to 1")
    if ret is None:
        args['num_epochs'] = 1
    else:
        args['num_epochs'] = ret

    ret = sd.askinteger(title="network_dim", prompt="What is the dim size you want to use?\nCancel will default to 128")
    if ret is None:
        args['net_dim'] = 128
    else:
        args['net_dim'] = ret

    ret = sd.askfloat(title="alpha", prompt="Alpha is the scalar of the training; generally a good starting point is "
                                            "0.5x the dim size\nWhat alpha do you want?\nCancel will default to "
                                            "0.5 x network_dim")
    if ret is None:
        args['alpha'] = args['net_dim'] / 2
    else:
        args['alpha'] = ret

    ret = sd.askinteger(title="resolution", prompt="How large of a resolution do you want to train at?\n"
                                                   "Cancel will default to 512")
    if ret is None:
        args['train_resolution'] = 512
    else:
        args['train_resolution'] = ret

    ret = sd.askfloat(title="learning_rate", prompt="What learning rate do you want to use?\n"
                                                    "Cancel will default to 1e-4")
    if ret is None:
        args['learning_rate'] = 1e-4
    else:
        args['learning_rate'] = ret

    ret = sd.askfloat(title="text_encoder_lr", prompt="Do you want to set the text_encoder_lr?\n"
                                                      "Cancel will default to None")
    if ret is None:
        args['text_encoder_lr'] = None
    else:
        args['text_encoder_lr'] = ret

    ret = sd.askfloat(title="unet_lr", prompt="Do you want to set the unet_lr?\nCancel will default to None")
    if ret is None:
        args['unet_lr'] = None
    else:
        args['unet_lr'] = ret

    button = ButtonBox("Which scheduler do you want?", ["cosine_with_restarts", "cosine", "polynomial",
                                                        "constant", "constant_with_warmup", "linear"])
    button.window.mainloop()
    args['scheduler'] = button.current_value if button.current_value != "" else "cosine_with_restarts"

    if args['scheduler'] == "cosine_with_restarts":
        ret = sd.askinteger(title="Cycle Count",
                            prompt="How many times do you want cosine to restart?\nThis is the total number of "
                                   "restarts over the entire training\nCancel will default to 1")
        if ret is None:
            args['cosine_restarts'] = 1
        else:
            args['cosine_restarts'] = ret

    if args['scheduler'] == "polynomial":
        ret = sd.askfloat(title="Poly Strength",
                          prompt="What power do you want to set your polynomial to?\nA higher power means the "
                                 "model reduces the learning rate more aggressively from initial training.\n1 = "
                                 "linear\nCancel sets it to 1")
        if ret is None:
            args['scheduler_power'] = 1
        else:
            args['scheduler_power'] = ret

    ret = mb.askyesno(message="Do you want to save checkpoints as it trains?")
    if ret:
        ret = sd.askinteger(title="save_epoch",
                            prompt="How often, in epochs, do you want to save?\nCancel will default to 1")
        if ret is None:
            args['save_every_n_epochs'] = 1
        else:
            args['save_every_n_epochs'] = ret

    ret = mb.askyesno(message="Do you want to shuffle captions?")
    if ret:
        args['shuffle_captions'] = True
    else:
        args['shuffle_captions'] = False

    ret = mb.askyesno(message="Do you want to keep some tokens at the front of your captions?")
    if ret:
        ret = sd.askinteger(title="keep_tokens", prompt="How many do you want to keep at the front?"
                                                        "\nCancel will default to 1")
        if ret is None:
            args['keep_tokens'] = 1
        else:
            args['keep_tokens'] = ret

    ret = mb.askyesno(message="Do you want to have a warmup ratio?")
    if ret:
        ret = sd.askfloat(title="warmup_ratio", prompt="What is the ratio of steps to use as warmup "
                                                       "steps?\nCancel will default to None")
        if ret is None:
            args['warmup_lr_ratio'] = None
        else:
            args['warmup_lr_ratio'] = ret

    ret = mb.askyesno(message="Do you want to change the name of output checkpoints?")
    if ret:
        ret = sd.askstring(title="output_name", prompt="What do you want your output name to be?\n"
                                                       "Cancel keeps the original name")
        if ret:
            args['change_output_name'] = ret
        else:
            args['change_output_name'] = None

    ret = sd.askstring(title="comment",
                       prompt="Do you want to set a comment that gets put into the metadata?\nA good use of this would "
                              "be to include how to use, such as activation keywords.\nCancel will leave it empty")
    # only keep the comment when the user actually entered one
    args['training_comment'] = ret if ret else None

    ret = mb.askyesno(message="Do you want to train only one of the unet and text encoder?")
    if ret:
        button = ButtonBox("Which do you want to train with?", ["unet_only", "text_only"])
        button.window.mainloop()
        if button.current_value != "":
            args[button.current_value] = True

    ret = mb.askyesno(message="Do you want to save a txt file that contains a list\n"
                              "of all tags that you have used in your training data?\n")
    if ret:
        args['tag_occurrence_txt_file'] = True
        button = ButtonBox("How do you want tags to be ordered?", ["alphabetically", "occurrence-ly"])
        button.window.mainloop()
        if button.current_value == "alphabetically":
            args['sort_tag_occurrence_alphabetically'] = True

    ret = mb.askyesno(message="Do you want to use caption dropout?")
    if ret:
        ret = mb.askyesno(message="Do you want full caption files to drop out randomly?")
        if ret:
            ret = sd.askinteger(title="Caption_File_Dropout",
                                prompt="How often do you want caption files to drop out?\n"
                                       "Enter a number from 0 to 100 that is the percentage chance of dropout\n"
                                       "Cancel sets it to 0")
            if ret and 0 <= ret <= 100:
                args['caption_dropout_rate'] = ret / 100.0

        ret = mb.askyesno(message="Do you want entire epochs to have no captions?")
        if ret:
            ret = sd.askinteger(title="Caption_epoch_dropout", prompt="The number set here is how often an epoch "
                                                                      "will have no captions.\nSo if you set 3, every "
                                                                      "third epoch will not have captions (3, 6, 9)\n"
                                                                      "Cancel will set it to None")
            if ret:
                args['caption_dropout_every_n_epochs'] = ret

        ret = mb.askyesno(message="Do you want tags to randomly drop out?")
        if ret:
            ret = sd.askinteger(title="Caption_tag_dropout", prompt="How often do you want tags to randomly drop out?\n"
                                                                    "Enter a number between 0 and 100, that is the "
                                                                    "percentage chance of dropout.\nCancel sets it to 0")
            if ret and 0 <= ret <= 100:
                args['caption_tag_dropout_rate'] = ret / 100.0

    ret = mb.askyesno(message="Do you want to use noise offset? Noise offset seems to allow SD to generate\n"
                              "darker or lighter images better than it normally would.")
    if ret:
        ret = sd.askfloat(title="noise_offset", prompt="What value do you want to set? The recommended value is 0.1,\n"
                                                       "but it can go higher. Cancel defaults to 0.1")
        if ret:
            args['noise_offset'] = ret
        else:
            args['noise_offset'] = 0.1
    return args


def save_json(path, obj: dict) -> None:
    with open(os.path.join(path, f"config-{time.time()}.json"), "w") as fp:
        json.dump(obj, fp=fp, indent=4)
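# Illustrative example of a saved config file (paths and values are made up; keys mirror the ArgStore attributes):
#   config-1674000000.0.json
#   {
#       "base_model": "C:/models/v1-5-pruned.safetensors",
#       "img_folder": "C:/training/img",
#       "output_folder": "C:/training/output",
#       "net_dim": 128,
#       "alpha": 128,
#       ...
#   }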


def load_json(path, obj: dict) -> dict:
    with open(path) as f:
        json_obj = json.loads(f.read())
    print("loaded json, setting variables...")
    # maps key names used by other configuration UIs to the names used by this script
    ui_name_scheme = {"pretrained_model_name_or_path": "base_model", "logging_dir": "log_dir",
                      "train_data_dir": "img_folder", "reg_data_dir": "reg_img_folder",
                      "output_dir": "output_folder", "max_resolution": "train_resolution",
                      "lr_scheduler": "scheduler", "lr_warmup": "warmup_lr_ratio",
                      "train_batch_size": "batch_size", "epoch": "num_epochs",
                      "save_at_n_epochs": "save_every_n_epochs", "num_cpu_threads_per_process": "num_workers",
                      "enable_bucket": "buckets", "save_model_as": "save_as", "shuffle_caption": "shuffle_captions",
                      "resume": "load_previous_save_state", "network_dim": "net_dim",
                      "gradient_accumulation_steps": "gradient_acc_steps", "output_name": "change_output_name",
                      "network_alpha": "alpha", "lr_scheduler_num_cycles": "cosine_restarts",
                      "lr_scheduler_power": "scheduler_power"}

    for key in list(json_obj):
        if key in ui_name_scheme:
            json_obj[ui_name_scheme[key]] = json_obj[key]
            if ui_name_scheme[key] in {"batch_size", "num_epochs"}:
                try:
                    json_obj[ui_name_scheme[key]] = int(json_obj[ui_name_scheme[key]])
                except ValueError:
                    print(f"attempting to load {key} from json failed as input isn't an integer")
                    quit(1)

    for key in list(json_obj):
        if obj["json_load_skip_list"] and key in obj["json_load_skip_list"]:
            continue
        if key in obj:
            if key in {"keep_tokens", "warmup_lr_ratio"}:
                json_obj[key] = int(json_obj[key]) if json_obj[key] is not None else None
            if key in {"learning_rate", "unet_lr", "text_encoder_lr"}:
                json_obj[key] = float(json_obj[key]) if json_obj[key] is not None else None
            if obj[key] != json_obj[key]:
                print_change(key, obj[key], json_obj[key])
                obj[key] = json_obj[key]
    print("completed changing variables.")
    return obj


def print_change(value, old, new):
    print(f"{value} changed from {old} to {new}")


class ButtonBox:
    def __init__(self, label: str, button_name_list: list[str]) -> None:
        self.window = tk.Tk()
        self.button_list = []
        self.current_value = ""

        self.window.attributes("-topmost", True)
        self.window.resizable(False, False)
        self.window.eval('tk::PlaceWindow . center')

        def del_window():
            self.window.quit()
            self.window.destroy()

        self.window.protocol("WM_DELETE_WINDOW", del_window)
        tk.Label(text=label, master=self.window).pack()
        for button in button_name_list:
            self.button_list.append(ttk.Button(text=button, master=self.window,
                                               command=partial(self.set_current_value, button)))
            self.button_list[-1].pack()

    def set_current_value(self, value):
        self.current_value = value
        self.window.quit()
        self.window.destroy()


root = tk.Tk()
root.attributes('-topmost', True)
root.withdraw()


if __name__ == "__main__":
    main()