|
|
|
|
|
import argparse |
|
import ast |
|
import asyncio |
|
import datetime |
|
import importlib |
|
import json |
|
import pathlib |
|
import re |
|
import shutil |
|
import time |
|
from typing import ( |
|
Dict, |
|
List, |
|
NamedTuple, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
Union, |
|
) |
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs |
|
import gc |
|
import glob |
|
import math |
|
import os |
|
import random |
|
import hashlib |
|
import subprocess |
|
from io import BytesIO |
|
import toml |
|
|
|
from tqdm import tqdm |
|
import torch |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.optim import Optimizer |
|
from torchvision import transforms |
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection |
|
import transformers |
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION |
|
from diffusers import ( |
|
StableDiffusionPipeline, |
|
DDPMScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
DPMSolverMultistepScheduler, |
|
DPMSolverSinglestepScheduler, |
|
LMSDiscreteScheduler, |
|
PNDMScheduler, |
|
DDIMScheduler, |
|
EulerDiscreteScheduler, |
|
HeunDiscreteScheduler, |
|
KDPM2DiscreteScheduler, |
|
KDPM2AncestralDiscreteScheduler, |
|
AutoencoderKL, |
|
) |
|
from external.llite.library import custom_train_functions |
|
from external.llite.library.original_unet import UNet2DConditionModel |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
from PIL import Image |
|
import cv2 |
|
import safetensors.torch |
|
from external.llite.library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline |
|
import external.llite.library.model_util as model_util |
|
import external.llite.library.huggingface_util as huggingface_util |
|
import external.llite.library.sai_model_spec as sai_model_spec |
|
|
|
|
|
|
|
from external.llite.library.original_unet import UNet2DConditionModel |
|
|
|
|
|
# Hugging Face-style repository identifiers (their consumers are outside this section).
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"

# str.format templates for per-epoch outputs: (name, number zero-padded to 6 digits).
EPOCH_STATE_NAME = "{}-{:06d}-state"
EPOCH_FILE_NAME = "{}-{:06d}"
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"

# Template / default names for the final output.
LAST_STATE_NAME = "{}-state"
DEFAULT_EPOCH_NAME = "epoch"
DEFAULT_LAST_OUTPUT_NAME = "last"

# str.format templates for per-step outputs: (name, number zero-padded to 8 digits).
DEFAULT_STEP_NAME = "at"
STEP_STATE_NAME = "{}-step{:08d}-state"
STEP_FILE_NAME = "{}-step{:08d}"
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
|
|
|
|
|
|
|
# File extensions (lower- and upper-case) treated as images.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional image plugins: each one is best-effort. A missing or broken plugin
# simply leaves its extensions unregistered. `except Exception` (instead of a
# bare `except`) keeps KeyboardInterrupt / SystemExit propagating.
try:
    import pillow_avif  # noqa: F401  (imported for its side effect of registering the format)

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except Exception:
    pass

try:
    from jxlpy import JXLImagePlugin  # noqa: F401

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:
    pass

try:
    import pillow_jxl  # noqa: F401

    # Both JXL plugins may be installed; register the extensions only once.
    if ".jxl" not in IMAGE_EXTENSIONS:
        IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:
    pass
|
|
|
# Shared preprocessing pipeline: ToTensor followed by Normalize with mean [0.5] and std [0.5].
IMAGE_TRANSFORMS = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Suffix appended to an image path (minus its extension) to name the text encoder output cache file.
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
|
|
|
|
class ImageInfo:
    """Mutable per-image record.

    The constructor stores the identifying fields; every other attribute
    starts as None and is filled in later by code outside this class
    (e.g. bucket assignment and latent / text-encoder caching).
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
        self.caption: str = caption
        self.is_reg: bool = is_reg
        self.absolute_path: str = absolute_path

        # Size / bucket information; None until computed.
        self.image_size: Optional[Tuple[int, int]] = None
        self.resized_size: Optional[Tuple[int, int]] = None
        self.bucket_reso: Optional[Tuple[int, int]] = None

        # Cached latents (in memory or as a path to an .npz file); None until cached.
        self.latents: Optional[torch.Tensor] = None
        self.latents_flipped: Optional[torch.Tensor] = None
        self.latents_npz: Optional[str] = None
        self.latents_original_size: Optional[Tuple[int, int]] = None
        # presumably (left, top, right, bottom), going by the name -- TODO confirm against the producer
        self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = None

        self.cond_img_path: Optional[str] = None
        self.image: Optional[Image.Image] = None

        # Cached text encoder outputs (in memory or as a path to an .npz file); None until cached.
        self.text_encoder_outputs_npz: Optional[str] = None
        self.text_encoder_outputs1: Optional[torch.Tensor] = None
        self.text_encoder_outputs2: Optional[torch.Tensor] = None
        self.text_encoder_pool2: Optional[torch.Tensor] = None
|
|
|
|
|
class BucketManager:
    """Groups images into resolution buckets.

    `resos` lists the bucket resolutions in use, `reso_to_id` maps a
    resolution to its index, and `buckets[i]` holds the items assigned to
    `resos[i]`.
    """

    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
        # Sanity checks are only possible when max_size is given.
        if max_size is not None:
            if max_reso is not None:
                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
            if min_size is not None:
                assert max_size >= min_size, "the max_size should be larger than the min_size"

        self.no_upscale = no_upscale
        self.max_reso = max_reso
        self.max_area = None if max_reso is None else max_reso[0] * max_reso[1]
        self.min_size = min_size
        self.max_size = max_size
        self.reso_steps = reso_steps

        self.resos = []
        self.reso_to_id = {}
        self.buckets = []

    def add_image(self, reso, image_or_info):
        """Append an item to the bucket registered for `reso` (KeyError if `reso` is unknown)."""
        self.buckets[self.reso_to_id[reso]].append(image_or_info)

    def shuffle(self):
        """Shuffle the contents of every bucket in place."""
        for members in self.buckets:
            random.shuffle(members)

    def sort(self):
        """Reorder `resos`, `buckets` and `reso_to_id` so that resolutions are in ascending order."""
        ordered = sorted(self.resos)
        # Look the buckets up through the old mapping before replacing it.
        reordered_buckets = [self.buckets[self.reso_to_id[reso]] for reso in ordered]
        new_ids = {reso: index for index, reso in enumerate(ordered)}

        self.resos = ordered
        self.buckets = reordered_buckets
        self.reso_to_id = new_ids

    def make_buckets(self):
        """Generate the predefined resolutions via model_util.make_bucket_resolutions and register them."""
        self.set_predefined_resos(
            model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
        )

    def set_predefined_resos(self, resos):
        """Store the candidate resolutions, plus a set and an aspect-ratio array derived from them."""
        self.predefined_resos = resos.copy()
        self.predefined_resos_set = set(resos)
        self.predefined_aspect_ratios = np.array([width / height for width, height in resos])

    def add_if_new_reso(self, reso):
        """Create an empty bucket for `reso` unless one already exists."""
        if reso in self.reso_to_id:
            return
        self.reso_to_id[reso] = len(self.resos)
        self.resos.append(reso)
        self.buckets.append([])

    def round_to_steps(self, x):
        """Round `x` to the nearest integer, then down to a multiple of `reso_steps`."""
        rounded = int(x + 0.5)
        return rounded - rounded % self.reso_steps

    def select_bucket(self, image_width, image_height):
        """Pick the bucket for an image of the given size.

        Returns `(bucket_reso, resized_size, ar_error)` where `ar_error` is the
        bucket aspect ratio minus the image aspect ratio. The bucket is
        registered if it did not exist yet.
        """
        aspect_ratio = image_width / image_height

        if self.no_upscale:
            # Downscale-only mode: the bucket is derived from the (possibly shrunk) image size.
            if image_width * image_height > self.max_area:
                # Shrink to max_area while keeping the aspect ratio.
                resized_width = math.sqrt(self.max_area * aspect_ratio)
                resized_height = self.max_area / resized_width
                assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"

                # Snap either the width or the height to reso_steps, whichever
                # gives an aspect ratio closer to the original.
                b_width_rounded = self.round_to_steps(resized_width)
                b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
                ar_width_rounded = b_width_rounded / b_height_in_wr

                b_height_rounded = self.round_to_steps(resized_height)
                b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
                ar_height_rounded = b_width_in_hr / b_height_rounded

                width_error = abs(ar_width_rounded - aspect_ratio)
                height_error = abs(ar_height_rounded - aspect_ratio)
                if width_error < height_error:
                    resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5))
                else:
                    resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded)
            else:
                resized_size = (image_width, image_height)

            # The bucket is the resized size rounded down to reso_steps on both axes.
            reso = (
                resized_size[0] - resized_size[0] % self.reso_steps,
                resized_size[1] - resized_size[1] % self.reso_steps,
            )
        else:
            # Up/down-scaling mode: prefer an exact resolution match, otherwise
            # take the predefined resolution with the closest aspect ratio.
            reso = (image_width, image_height)
            if reso not in self.predefined_resos_set:
                ar_errors = self.predefined_aspect_ratios - aspect_ratio
                reso = self.predefined_resos[np.abs(ar_errors).argmin()]

            # Scale so that the image covers the bucket on both axes.
            if aspect_ratio > reso[0] / reso[1]:
                scale = reso[1] / image_height
            else:
                scale = reso[0] / image_width

            resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))

        self.add_if_new_reso(reso)

        ar_error = (reso[0] / reso[1]) - aspect_ratio
        return reso, resized_size, ar_error

    @staticmethod
    def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
        """Return (left, top, right, bottom) of the image area centred inside the bucket.

        The image is fitted inside the bucket while keeping its aspect ratio.
        Values may be floats because true division is used.
        """
        bucket_width, bucket_height = bucket_reso[0], bucket_reso[1]
        image_ar = image_size[0] / image_size[1]

        if bucket_width / bucket_height > image_ar:
            # Bucket is wider than the image: fit the height.
            fitted_width, fitted_height = bucket_height * image_ar, bucket_height
        else:
            # Bucket is taller than (or equal to) the image: fit the width.
            fitted_width, fitted_height = bucket_width, bucket_width / image_ar

        crop_left = (bucket_width - fitted_width) // 2
        crop_top = (bucket_height - fitted_height) // 2
        return crop_left, crop_top, crop_left + fitted_width, crop_top + fitted_height
|
|
|
|
|
class BucketBatchIndex(NamedTuple):
    """Locates one batch: which bucket, the batch size used for it, and the batch number inside it."""

    bucket_index: int  # index into BucketManager.buckets
    bucket_batch_size: int  # number of items per batch for this bucket
    batch_index: int  # 0-based batch number within the bucket
|
|
|
|
|
class AugHelper:
    """Provides the random colour augmentation callable used for training images."""

    def __init__(self):
        pass

    def color_aug(self, image: np.ndarray):
        """Randomly perturb hue or gamma and return `{"image": image}`.

        About a third of the calls modify the image; of those, half shift the
        hue and half apply a gamma curve. The hue path converts with
        cv2.COLOR_BGR2HSV, so the input is presumably BGR uint8 -- confirm with
        the caller.
        """
        hue_shift_limit = 8

        # Leave the image untouched most of the time.
        if random.random() > 0.33:
            return {"image": image}

        if random.random() > 0.5:
            # Hue shift; negative shifts are wrapped into OpenCV's 0..179 hue range.
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            shift = random.uniform(-hue_shift_limit, hue_shift_limit)
            if shift < 0:
                shift = 180 + shift
            hsv[:, :, 0] = (hsv[:, :, 0] + shift) % 180
            image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        else:
            # Mild gamma change applied directly to the raw pixel values.
            gamma = random.uniform(0.95, 1.05)
            image = np.clip(image**gamma, 0, 255).astype(np.uint8)

        return {"image": image}

    def get_augmentor(self, use_color_aug: bool):
        """Return `color_aug` when colour augmentation is requested, otherwise None."""
        if use_color_aug:
            return self.color_aug
        return None
|
|
|
|
|
class BaseSubset:
    """Settings shared by every kind of dataset subset.

    All constructor arguments are stored unchanged on attributes of the same
    name; `img_count` starts at 0.
    """

    def __init__(
        self,
        image_dir: Optional[str],
        num_repeats: int,
        shuffle_caption: bool,
        caption_separator: str,
        keep_tokens: int,
        keep_tokens_separator: str,
        color_aug: bool,
        flip_aug: bool,
        face_crop_aug_range: Optional[Tuple[float, float]],
        random_crop: bool,
        caption_dropout_rate: float,
        caption_dropout_every_n_epochs: int,
        caption_tag_dropout_rate: float,
        caption_prefix: Optional[str],
        caption_suffix: Optional[str],
        token_warmup_min: int,
        token_warmup_step: Union[float, int],
    ) -> None:
        # Source and repetition.
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.img_count = 0

        # Caption handling.
        self.shuffle_caption = shuffle_caption
        self.caption_separator = caption_separator
        self.keep_tokens = keep_tokens
        self.keep_tokens_separator = keep_tokens_separator
        self.caption_prefix = caption_prefix
        self.caption_suffix = caption_suffix

        # Caption dropout.
        self.caption_dropout_rate = caption_dropout_rate
        self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
        self.caption_tag_dropout_rate = caption_tag_dropout_rate

        # Image augmentation.
        self.color_aug = color_aug
        self.flip_aug = flip_aug
        self.face_crop_aug_range = face_crop_aug_range
        self.random_crop = random_crop

        # Token warm-up.
        self.token_warmup_min = token_warmup_min
        self.token_warmup_step = token_warmup_step
|
|
|
|
|
class DreamBoothSubset(BaseSubset):
    """Subset read from an image directory, with optional class tokens and caption files.

    Two instances compare equal when their `image_dir` matches. Because
    `__eq__` is defined without `__hash__`, instances are unhashable.
    """

    def __init__(
        self,
        image_dir: str,
        is_reg: bool,
        class_tokens: Optional[str],
        caption_extension: str,
        num_repeats,
        shuffle_caption,
        caption_separator: str,
        keep_tokens,
        keep_tokens_separator,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            keep_tokens_separator=keep_tokens_separator,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.is_reg = is_reg
        self.class_tokens = class_tokens

        # Normalise a non-empty extension so that it always starts with a dot.
        if caption_extension and not caption_extension.startswith("."):
            caption_extension = "." + caption_extension
        self.caption_extension = caption_extension

    def __eq__(self, other) -> bool:
        if isinstance(other, DreamBoothSubset):
            return self.image_dir == other.image_dir
        return NotImplemented
|
|
|
|
|
class FineTuningSubset(BaseSubset):
    """Subset described by a metadata file.

    Two instances compare equal when their `metadata_file` matches. Because
    `__eq__` is defined without `__hash__`, instances are unhashable.
    """

    def __init__(
        self,
        image_dir,
        metadata_file: str,
        num_repeats,
        shuffle_caption,
        caption_separator,
        keep_tokens,
        keep_tokens_separator,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            keep_tokens_separator=keep_tokens_separator,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.metadata_file = metadata_file

    def __eq__(self, other) -> bool:
        if isinstance(other, FineTuningSubset):
            return self.metadata_file == other.metadata_file
        return NotImplemented
|
|
|
|
|
class ControlNetSubset(BaseSubset):
    """Subset pairing an image directory with a conditioning-data directory.

    Two instances compare equal when both `image_dir` and
    `conditioning_data_dir` match. Because `__eq__` is defined without
    `__hash__`, instances are unhashable.
    """

    def __init__(
        self,
        image_dir: str,
        conditioning_data_dir: str,
        caption_extension: str,
        num_repeats,
        shuffle_caption,
        caption_separator,
        keep_tokens,
        keep_tokens_separator,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            keep_tokens_separator=keep_tokens_separator,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.conditioning_data_dir = conditioning_data_dir

        # Normalise a non-empty extension so that it always starts with a dot.
        if caption_extension and not caption_extension.startswith("."):
            caption_extension = "." + caption_extension
        self.caption_extension = caption_extension

    def __eq__(self, other) -> bool:
        if isinstance(other, ControlNetSubset):
            same_images = self.image_dir == other.image_dir
            return same_images and self.conditioning_data_dir == other.conditioning_data_dir
        return NotImplemented
|
|
|
|
|
class BaseDataset(torch.utils.data.Dataset): |
|
def __init__(
    self,
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
    max_token_length: int,
    resolution: Optional[Tuple[int, int]],
    debug_dataset: bool,
) -> None:
    """Set up the state shared by all dataset types.

    `tokenizer` may be a single tokenizer or a list of them. `resolution`
    is unpacked into `width` and `height` (both None when it is None).
    """
    super().__init__()

    # Always keep the tokenizers as a list.
    if isinstance(tokenizer, list):
        self.tokenizers = tokenizer
    else:
        self.tokenizers = [tokenizer]

    self.max_token_length = max_token_length
    # Without an explicit limit use the first tokenizer's own maximum; otherwise add 2 to the
    # requested length (get_input_ids keeps the first and last token around every chunk).
    if max_token_length is None:
        self.tokenizer_max_length = self.tokenizers[0].model_max_length
    else:
        self.tokenizer_max_length = max_token_length + 2

    if resolution is None:
        self.width, self.height = None, None
    else:
        self.width, self.height = resolution
    self.debug_dataset = debug_dataset

    self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []

    # Caption / token options.
    self.token_padding_disabled = False
    self.tag_frequency = {}
    self.XTI_layers = None
    self.token_strings = None
    self.replacements = {}

    # Bucketing configuration; make_buckets() reads these.
    self.enable_bucket = False
    self.bucket_manager: BucketManager = None
    self.min_bucket_reso = None
    self.max_bucket_reso = None
    self.bucket_reso_steps = None
    self.bucket_no_upscale = None
    self.bucket_info = None

    # Training progress, updated through the set_* methods.
    self.current_epoch: int = 0
    self.current_step: int = 0
    self.max_train_steps: int = 0
    self.seed: int = 0

    # Augmentation and preprocessing.
    self.aug_helper = AugHelper()
    self.image_transforms = IMAGE_TRANSFORMS

    # Registered images and the subset each one belongs to, keyed by image key.
    self.image_data: Dict[str, ImageInfo] = {}
    self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}

    # When not None, __getitem__ delegates to get_item_for_caching().
    self.caching_mode = None
|
|
|
def set_seed(self, seed): |
|
self.seed = seed |
|
|
|
def set_caching_mode(self, mode): |
|
self.caching_mode = mode |
|
|
|
def set_current_epoch(self, epoch): |
|
if not self.current_epoch == epoch: |
|
self.shuffle_buckets() |
|
self.current_epoch = epoch |
|
|
|
def set_current_step(self, step): |
|
self.current_step = step |
|
|
|
def set_max_train_steps(self, max_train_steps): |
|
self.max_train_steps = max_train_steps |
|
|
|
def set_tag_frequency(self, dir_name, captions): |
|
frequency_for_dir = self.tag_frequency.get(dir_name, {}) |
|
self.tag_frequency[dir_name] = frequency_for_dir |
|
for caption in captions: |
|
for tag in caption.split(","): |
|
tag = tag.strip() |
|
if tag: |
|
tag = tag.lower() |
|
frequency = frequency_for_dir.get(tag, 0) |
|
frequency_for_dir[tag] = frequency + 1 |
|
|
|
def disable_token_padding(self): |
|
self.token_padding_disabled = True |
|
|
|
def enable_XTI(self, layers=None, token_strings=None): |
|
self.XTI_layers = layers |
|
self.token_strings = token_strings |
|
|
|
def add_replacement(self, str_from, str_to): |
|
self.replacements[str_from] = str_to |
|
|
|
def process_caption(self, subset: BaseSubset, caption): |
|
|
|
if subset.caption_prefix: |
|
caption = subset.caption_prefix + " " + caption |
|
if subset.caption_suffix: |
|
caption = caption + " " + subset.caption_suffix |
|
|
|
|
|
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate |
|
is_drop_out = ( |
|
is_drop_out |
|
or subset.caption_dropout_every_n_epochs > 0 |
|
and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 |
|
) |
|
|
|
if is_drop_out: |
|
caption = "" |
|
else: |
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: |
|
fixed_tokens = [] |
|
flex_tokens = [] |
|
if ( |
|
hasattr(subset, "keep_tokens_separator") |
|
and subset.keep_tokens_separator |
|
and subset.keep_tokens_separator in caption |
|
): |
|
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) |
|
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] |
|
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] |
|
else: |
|
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] |
|
flex_tokens = tokens[:] |
|
if subset.keep_tokens > 0: |
|
fixed_tokens = flex_tokens[: subset.keep_tokens] |
|
flex_tokens = tokens[subset.keep_tokens :] |
|
|
|
if subset.token_warmup_step < 1: |
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) |
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step: |
|
tokens_len = ( |
|
math.floor( |
|
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) |
|
) |
|
+ subset.token_warmup_min |
|
) |
|
flex_tokens = flex_tokens[:tokens_len] |
|
|
|
def dropout_tags(tokens): |
|
if subset.caption_tag_dropout_rate <= 0: |
|
return tokens |
|
l = [] |
|
for token in tokens: |
|
if random.random() >= subset.caption_tag_dropout_rate: |
|
l.append(token) |
|
return l |
|
|
|
if subset.shuffle_caption: |
|
random.shuffle(flex_tokens) |
|
|
|
flex_tokens = dropout_tags(flex_tokens) |
|
|
|
caption = ", ".join(fixed_tokens + flex_tokens) |
|
|
|
|
|
for str_from, str_to in self.replacements.items(): |
|
if str_from == "": |
|
|
|
if type(str_to) == list: |
|
caption = random.choice(str_to) |
|
else: |
|
caption = str_to |
|
else: |
|
caption = caption.replace(str_from, str_to) |
|
|
|
return caption |
|
|
|
def get_input_ids(self, caption, tokenizer=None): |
|
if tokenizer is None: |
|
tokenizer = self.tokenizers[0] |
|
|
|
input_ids = tokenizer( |
|
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" |
|
).input_ids |
|
|
|
if self.tokenizer_max_length > tokenizer.model_max_length: |
|
input_ids = input_ids.squeeze(0) |
|
iids_list = [] |
|
if tokenizer.pad_token_id == tokenizer.eos_token_id: |
|
|
|
|
|
|
|
for i in range( |
|
1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2 |
|
): |
|
ids_chunk = ( |
|
input_ids[0].unsqueeze(0), |
|
input_ids[i : i + tokenizer.model_max_length - 2], |
|
input_ids[-1].unsqueeze(0), |
|
) |
|
ids_chunk = torch.cat(ids_chunk) |
|
iids_list.append(ids_chunk) |
|
else: |
|
|
|
|
|
for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): |
|
ids_chunk = ( |
|
input_ids[0].unsqueeze(0), |
|
input_ids[i : i + tokenizer.model_max_length - 2], |
|
input_ids[-1].unsqueeze(0), |
|
) |
|
ids_chunk = torch.cat(ids_chunk) |
|
|
|
|
|
|
|
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: |
|
ids_chunk[-1] = tokenizer.eos_token_id |
|
|
|
if ids_chunk[1] == tokenizer.pad_token_id: |
|
ids_chunk[1] = tokenizer.eos_token_id |
|
|
|
iids_list.append(ids_chunk) |
|
|
|
input_ids = torch.stack(iids_list) |
|
return input_ids |
|
|
|
def register_image(self, info: ImageInfo, subset: BaseSubset): |
|
self.image_data[info.image_key] = info |
|
self.image_to_subset[info.image_key] = subset |
|
|
|
def make_buckets(self):
    """
    Assign every registered image to a bucket and build the batch index.

    Must be called even when bucketing is disabled (a single bucket is created in that case).
    min_size and max_size are ignored when enable_bucket is False.

    Side effects: fills in missing `image_size` values, sets `bucket_reso` / `resized_size`
    on every image, populates `self.bucket_manager`, `self.bucket_info` (bucketing only),
    `self.buckets_indices` and `self._length`, and shuffles the buckets once.
    Reads `self.batch_size`, which is not set by BaseDataset.__init__.
    """
    print("loading image sizes.")
    for info in tqdm(self.image_data.values()):
        if info.image_size is None:
            info.image_size = self.get_image_size(info.absolute_path)

    print("make buckets" if self.enable_bucket else "prepare dataset")

    # Create the bucket manager and pick a bucket for every image.
    if self.enable_bucket:
        # A bucket manager may already have been supplied; only build one if it is missing.
        if self.bucket_manager is None:
            self.bucket_manager = BucketManager(
                self.bucket_no_upscale,
                (self.width, self.height),
                self.min_bucket_reso,
                self.max_bucket_reso,
                self.bucket_reso_steps,
            )
            if not self.bucket_no_upscale:
                self.bucket_manager.make_buckets()
            else:
                print(
                    "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
                )

        img_ar_errors = []
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
                image_width, image_height
            )
            img_ar_errors.append(abs(ar_error))

        self.bucket_manager.sort()
    else:
        # Bucketing disabled: a single fixed-size bucket at the training resolution.
        self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
        self.bucket_manager.set_predefined_resos([(self.width, self.height)])
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)

    # Each image key is added once per repeat.
    for image_info in self.image_data.values():
        for _ in range(image_info.num_repeats):
            self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)

    # Report and store bucket statistics.
    if self.enable_bucket:
        self.bucket_info = {"buckets": {}}
        print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
        for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
            count = len(bucket)
            if count > 0:
                self.bucket_info["buckets"][i] = {"resolution": reso, "count": count}
                print(f"bucket {i}: resolution {reso}, count: {count}")

        img_ar_errors = np.array(img_ar_errors)
        mean_img_ar_error = np.mean(np.abs(img_ar_errors))
        self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
        print(f"mean ar error (without repeats): {mean_img_ar_error}")

    # Build the flat (bucket, batch) index that __getitem__ and shuffle_buckets work on.
    self.buckets_indices: List[BucketBatchIndex] = []
    for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
        batch_count = int(math.ceil(len(bucket) / self.batch_size))
        for batch_index in range(batch_count):
            self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))

    self.shuffle_buckets()
    self._length = len(self.buckets_indices)
|
|
|
def shuffle_buckets(self): |
|
|
|
random.seed(self.seed + self.current_epoch) |
|
|
|
random.shuffle(self.buckets_indices) |
|
self.bucket_manager.shuffle() |
|
|
|
def verify_bucket_reso_steps(self, min_steps: int): |
|
assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( |
|
f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" |
|
+ f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" |
|
) |
|
|
|
def is_latent_cacheable(self): |
|
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) |
|
|
|
def is_text_encoder_output_cacheable(self): |
|
return all( |
|
[ |
|
not ( |
|
subset.caption_dropout_rate > 0 |
|
or subset.shuffle_caption |
|
or subset.token_warmup_step > 0 |
|
or subset.caption_tag_dropout_rate > 0 |
|
) |
|
for subset in self.subsets |
|
] |
|
) |
|
|
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
    """Cache VAE latents for every registered image, in memory or on disk.

    Images are sorted by bucket area and grouped into batches of at most
    `vae_batch_size` items that share the same bucket resolution and the same
    subset `flip_aug` / `random_crop` settings. Images that already have
    `latents_npz` set, or whose on-disk cache is reported valid by
    `is_disk_cached_latents_is_expected`, are skipped. With `cache_to_disk`
    on a non-main process only `info.latents_npz` is filled in.
    """
    print("caching latents.")

    def aug_settings(info):
        # Augmentation flags of the subset owning `info`; these are passed on to cache_batch_latents.
        owner = self.image_to_subset[info.image_key]
        return owner.flip_aug, owner.random_crop

    image_infos = list(self.image_data.values())

    # Sort by resolution so that images with the same bucket end up adjacent.
    image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

    batches = []
    batch = []
    print("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        # The cache path was already assigned elsewhere; nothing to do for this image.
        if info.latents_npz is not None:
            continue

        if cache_to_disk:
            info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
            if not is_main_process:
                # Non-main processes only record the path.
                continue

            if is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug):
                continue

        # Flush when the resolution or the augmentation settings change, so that each batch
        # is cached with the settings of the subset(s) it actually contains. (Previously the
        # settings of whichever subset was seen last were applied to every batch.)
        if batch and (
            batch[-1].bucket_reso != info.bucket_reso or aug_settings(batch[-1]) != aug_settings(info)
        ):
            batches.append(batch)
            batch = []

        batch.append(info)

        # Flush when the batch is full.
        if len(batch) >= vae_batch_size:
            batches.append(batch)
            batch = []

    if batch:
        batches.append(batch)

    if cache_to_disk and not is_main_process:
        return

    print("caching latents...")
    for batch in tqdm(batches, smoothing=1, total=len(batches)):
        flip_aug, random_crop = aug_settings(batch[0])
        cache_batch_latents(vae, cache_to_disk, batch, flip_aug, random_crop)
|
|
|
|
|
|
|
|
|
def cache_text_encoder_outputs(
    self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
    """Cache text encoder outputs for every registered image (two-tokenizer models only).

    With `cache_to_disk`, each image gets `text_encoder_outputs_npz` set to its
    path with TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, images whose cache file
    already exists are skipped, and non-main processes return after recording
    the paths. The remaining captions are tokenized with both tokenizers,
    grouped into batches of `self.batch_size`, and handed to
    `cache_batch_text_encoder_outputs`.
    """
    assert len(tokenizers) == 2, "only support SDXL"

    print("caching text encoder outputs.")
    image_infos = list(self.image_data.values())

    print("checking cache existence...")
    pending_infos = []
    for info in tqdm(image_infos):
        if cache_to_disk:
            npz_path = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = npz_path

            # Non-main processes only record the path; existing caches are reused.
            if not is_main_process or os.path.exists(npz_path):
                continue

        pending_infos.append(info)

    if cache_to_disk and not is_main_process:
        return

    # Move the text encoders to the target device (and dtype, when one is given).
    for encoder in text_encoders:
        encoder.to(device)
        if weight_dtype is not None:
            encoder.to(dtype=weight_dtype)

    # Tokenize each caption with both tokenizers and group the results into batches.
    batches = []
    current = []
    for info in pending_infos:
        ids1 = self.get_input_ids(info.caption, tokenizers[0])
        ids2 = self.get_input_ids(info.caption, tokenizers[1])
        current.append((info, ids1, ids2))

        if len(current) >= self.batch_size:
            batches.append(current)
            current = []

    if len(current) > 0:
        batches.append(current)

    print("caching text encoder outputs...")
    for entries in tqdm(batches):
        infos, ids1_per_image, ids2_per_image = zip(*entries)
        stacked_ids1 = torch.stack(ids1_per_image, dim=0)
        stacked_ids2 = torch.stack(ids2_per_image, dim=0)
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, stacked_ids1, stacked_ids2, weight_dtype
        )
|
|
|
def get_image_size(self, image_path):
    """Return the `(width, height)` of the image at `image_path`.

    The file is opened through a context manager so that the underlying file
    handle is released immediately; this method is called once per image
    while the dataset is being prepared.
    """
    with Image.open(image_path) as image:
        return image.size
|
|
|
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
    """Load an image and, if face cropping is enabled, parse face info from its file name.

    When `subset.face_crop_aug_range` is set and the file name (without
    extension) has at least five "_"-separated parts, the last four are read
    as integers: face centre x, centre y, width, height. Otherwise all four
    values are 0. Returns `(img, face_cx, face_cy, face_w, face_h)`.
    """
    img = load_image(image_path)

    face_cx, face_cy, face_w, face_h = 0, 0, 0, 0
    if subset.face_crop_aug_range is not None:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        parts = stem.split("_")
        if len(parts) >= 5:
            face_cx, face_cy, face_w, face_h = [int(part) for part in parts[-4:]]

    return img, face_cx, face_cy, face_w, face_h
|
|
|
|
|
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): |
|
height, width = image.shape[0:2] |
|
if height == self.height and width == self.width: |
|
return image |
|
|
|
|
|
face_size = max(face_w, face_h) |
|
size = min(self.height, self.width) |
|
min_scale = max(self.height / height, self.width / width) |
|
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) |
|
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) |
|
if min_scale >= max_scale: |
|
scale = min_scale |
|
else: |
|
scale = random.uniform(min_scale, max_scale) |
|
|
|
nh = int(height * scale + 0.5) |
|
nw = int(width * scale + 0.5) |
|
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" |
|
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) |
|
face_cx = int(face_cx * scale + 0.5) |
|
face_cy = int(face_cy * scale + 0.5) |
|
height, width = nh, nw |
|
|
|
|
|
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): |
|
p1 = face_p - target_size // 2 |
|
|
|
if subset.random_crop: |
|
|
|
range = max(length - face_p, face_p) |
|
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range |
|
else: |
|
|
|
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: |
|
if face_size > size // 10 and face_size >= 40: |
|
p1 = p1 + random.randint(-face_size // 20, +face_size // 20) |
|
|
|
p1 = max(0, min(p1, length - target_size)) |
|
|
|
if axis == 0: |
|
image = image[p1 : p1 + target_size, :] |
|
else: |
|
image = image[:, p1 : p1 + target_size] |
|
|
|
return image |
|
|
|
def __len__(self): |
|
return self._length |
|
|
|
def __getitem__(self, index):
    """Build the batch stored at position ``index`` of ``self.buckets_indices``.

    When ``self.caching_mode`` is set, the work is handed to
    ``get_item_for_caching`` and its result is returned unchanged.

    Otherwise every image key of the batch is turned into either cached latents
    (in memory or loaded from an npz file) or a freshly loaded, cropped,
    augmented and transformed image, plus either cached text-encoder outputs or
    a processed caption with its token ids.

    Returns:
        dict with the keys ``loss_weights``, ``input_ids``, ``input_ids2``,
        ``text_encoder_outputs1_list``, ``text_encoder_outputs2_list``,
        ``text_encoder_pool2_list``, ``images``, ``latents``, ``captions``,
        ``original_sizes_hw``, ``crop_top_lefts``, ``target_sizes_hw`` and
        ``flippeds``; ``image_keys`` is added when ``self.debug_dataset`` is true.
        Entries that do not apply to the batch are ``None``.
    """
    bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
    bucket_batch_size = self.buckets_indices[index].bucket_batch_size
    image_index = self.buckets_indices[index].batch_index * bucket_batch_size

    if self.caching_mode is not None:
        return self.get_item_for_caching(bucket, bucket_batch_size, image_index)

    loss_weights = []
    captions = []
    input_ids_list = []
    input_ids2_list = []
    latents_list = []
    images = []
    original_sizes_hw = []
    crop_top_lefts = []
    target_sizes_hw = []
    flippeds = []
    text_encoder_outputs1_list = []
    text_encoder_outputs2_list = []
    text_encoder_pool2_list = []

    for image_key in bucket[image_index : image_index + bucket_batch_size]:
        image_info = self.image_data[image_key]
        subset = self.image_to_subset[image_key]
        loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

        # With flip_aug enabled, flip with probability 0.5.
        flipped = subset.flip_aug and random.random() < 0.5

        # ---- image / latents ----
        if image_info.latents is not None:
            # Latents are already held in memory on the image info.
            original_size = image_info.latents_original_size
            crop_ltrb = image_info.latents_crop_ltrb
            if not flipped:
                latents = image_info.latents
            else:
                latents = image_info.latents_flipped

            image = None
        elif image_info.latents_npz is not None:
            # Latents are stored on disk in an npz file.
            latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
            if flipped:
                latents = flipped_latents
                del flipped_latents
            latents = torch.FloatTensor(latents)

            image = None
        else:
            # No cached latents: load the image and crop it if required.
            img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
            im_h, im_w = img.shape[0:2]

            if self.enable_bucket:
                img, original_size, crop_ltrb = trim_and_resize_if_required(
                    subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
                )
            else:
                if face_cx > 0:
                    # A positive face_cx is treated as "face position available".
                    img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
                elif im_h > self.height or im_w > self.width:
                    assert (
                        subset.random_crop
                    ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
                    if im_h > self.height:
                        p = random.randint(0, im_h - self.height)
                        img = img[p : p + self.height]
                    if im_w > self.width:
                        p = random.randint(0, im_w - self.width)
                        img = img[:, p : p + self.width]

                im_h, im_w = img.shape[0:2]
                assert (
                    im_h == self.height and im_w == self.width
                ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

                original_size = [im_w, im_h]
                crop_ltrb = (0, 0, 0, 0)

            # Augmentation
            aug = self.aug_helper.get_augmentor(subset.color_aug)
            if aug is not None:
                img = aug(image=img)["image"]

            if flipped:
                # Copy so the result does not keep a negative stride.
                img = img[:, ::-1, :].copy()

            latents = None
            image = self.image_transforms(img)

        images.append(image)
        latents_list.append(latents)

        # (width, height). The factor 8 is presumably the latent-to-pixel scale -- TODO confirm.
        target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

        if not flipped:
            crop_left_top = (crop_ltrb[0], crop_ltrb[1])
        else:
            # crop_ltrb[2] is the right edge, so width - right is the left edge after flipping.
            crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

        # Stored in (height, width) / (top, left) order.
        original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
        crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
        target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
        flippeds.append(flipped)

        # ---- caption / text encoder outputs ----
        caption = image_info.caption
        if image_info.text_encoder_outputs1 is not None:
            text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
            text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
            text_encoder_pool2_list.append(image_info.text_encoder_pool2)
            captions.append(caption)
        elif image_info.text_encoder_outputs_npz is not None:
            text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
                image_info.text_encoder_outputs_npz
            )
            text_encoder_outputs1_list.append(text_encoder_outputs1)
            text_encoder_outputs2_list.append(text_encoder_outputs2)
            text_encoder_pool2_list.append(text_encoder_pool2)
            captions.append(caption)
        else:
            caption = self.process_caption(subset, image_info.caption)
            if self.XTI_layers:
                # One caption per layer, with the token strings suffixed by the layer name.
                caption_layer = []
                for layer in self.XTI_layers:
                    token_strings_from = " ".join(self.token_strings)
                    token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
                    caption_ = caption.replace(token_strings_from, token_strings_to)
                    caption_layer.append(caption_)
                captions.append(caption_layer)
            else:
                captions.append(caption)

            if not self.token_padding_disabled:
                if self.XTI_layers:
                    token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
                else:
                    token_caption = self.get_input_ids(caption, self.tokenizers[0])
                input_ids_list.append(token_caption)

                if len(self.tokenizers) > 1:
                    if self.XTI_layers:
                        token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
                    else:
                        token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
                    input_ids2_list.append(token_caption2)

    example = {}
    example["loss_weights"] = torch.FloatTensor(loss_weights)

    if len(text_encoder_outputs1_list) == 0:
        if self.token_padding_disabled:
            # padding=True pads to the longest caption within this batch.
            # Fixed: this branch used `self.tokenizer`, while the rest of the
            # method reads the tokenizer list from `self.tokenizers`.
            example["input_ids"] = self.tokenizers[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
            if len(self.tokenizers) > 1:
                example["input_ids2"] = self.tokenizers[1](
                    captions, padding=True, truncation=True, return_tensors="pt"
                ).input_ids
            else:
                example["input_ids2"] = None
        else:
            example["input_ids"] = torch.stack(input_ids_list)
            example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
        example["text_encoder_outputs1_list"] = None
        example["text_encoder_outputs2_list"] = None
        example["text_encoder_pool2_list"] = None
    else:
        example["input_ids"] = None
        example["input_ids2"] = None
        example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
        example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
        example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)

    # The first entry decides whether the batch is image-based or latent-based.
    if images[0] is not None:
        images = torch.stack(images)
        images = images.to(memory_format=torch.contiguous_format).float()
    else:
        images = None
    example["images"] = images

    example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
    example["captions"] = captions

    example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
    example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
    example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
    example["flippeds"] = flippeds

    if self.debug_dataset:
        # Same slice as the loop above, so the keys line up with the batch entries.
        example["image_keys"] = bucket[image_index : image_index + bucket_batch_size]
    return example
|
|
|
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
    """Collect the raw inputs of one batch for a caching pass.

    Images are loaded only when ``self.caching_mode`` is ``"latents"`` and
    captions are tokenized (with the first two tokenizers) only when it is
    ``"text"``; otherwise the corresponding entries are ``None``.

    All items of a batch must share ``flip_aug``, ``random_crop`` and
    ``bucket_reso``; this is enforced with assertions.

    Returns:
        dict with ``images`` (``None`` if the first image is ``None``),
        ``captions``, ``input_ids1_list``, ``input_ids2_list``,
        ``absolute_paths``, ``resized_sizes``, ``flip_aug``, ``random_crop``
        and ``bucket_reso``.
    """
    batch_keys = bucket[image_index : image_index + bucket_batch_size]

    captions, images = [], []
    input_ids1_list, input_ids2_list = [], []
    absolute_paths, resized_sizes = [], []
    bucket_reso = flip_aug = random_crop = None

    for key in batch_keys:
        info = self.image_data[key]
        owner = self.image_to_subset[key]

        if flip_aug is not None:
            # Later items must agree with the settings recorded so far.
            assert flip_aug == owner.flip_aug, "flip_aug must be same in a batch"
            assert random_crop == owner.random_crop, "random_crop must be same in a batch"
            assert bucket_reso == info.bucket_reso, "bucket_reso must be same in a batch"
        else:
            flip_aug = owner.flip_aug
            random_crop = owner.random_crop
            bucket_reso = info.bucket_reso

        caption = info.caption

        image = load_image(info.absolute_path) if self.caching_mode == "latents" else None

        if self.caching_mode == "text":
            input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
            input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
        else:
            input_ids1 = input_ids2 = None

        captions.append(caption)
        images.append(image)
        input_ids1_list.append(input_ids1)
        input_ids2_list.append(input_ids2)
        absolute_paths.append(info.absolute_path)
        resized_sizes.append(info.resized_size)

    return {
        "images": None if images[0] is None else images,
        "captions": captions,
        "input_ids1_list": input_ids1_list,
        "input_ids2_list": input_ids2_list,
        "absolute_paths": absolute_paths,
        "resized_sizes": resized_sizes,
        "flip_aug": flip_aug,
        "random_crop": random_crop,
        "bucket_reso": bucket_reso,
    }
|
|
|
|
|
class DreamBoothDataset(BaseDataset):
    """Dataset built from image directories with optional per-image caption files.

    Each subset points at an ``image_dir``. Captions come from a sidecar file
    with the subset's ``caption_extension``; when that file is missing the
    subset's ``class_tokens`` are used, or an empty caption if those are absent
    as well. Subsets flagged ``is_reg`` provide regularization images, which are
    repeated until their count reaches the number of (repeated) training images.
    """

    def __init__(
        self,
        subsets: Sequence[DreamBoothSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        debug_dataset,
    ) -> None:
        """Scan every subset directory and register the images found.

        Args:
            subsets: Subsets to load. Subsets with ``num_repeats < 1``, duplicates
                of already registered subsets, and subsets without images are
                skipped with a printed message.
            batch_size: Stored as ``self.batch_size``.
            tokenizer, max_token_length, resolution, debug_dataset: Forwarded to
                the base-class initialiser. ``resolution`` must not be ``None``.
            enable_bucket: When true, the bucket resolution limits are validated
                against ``resolution`` and stored; otherwise they are set to
                ``None`` / ``False``.
            min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale:
                Bucketing parameters, only kept when ``enable_bucket`` is true.
            prior_loss_weight: Stored as ``self.prior_loss_weight``.

        Raises:
            AssertionError: If ``resolution`` is ``None``, the bucket limits do
                not enclose ``resolution``, or a caption file is empty.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.batch_size = batch_size
        self.size = min(self.width, self.height)
        self.prior_loss_weight = prior_loss_weight
        self.latents_cache = None

        self.enable_bucket = enable_bucket
        if self.enable_bucket:
            assert (
                min(resolution) >= min_bucket_reso
            ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
            assert (
                max(resolution) <= max_bucket_reso
            ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
            self.min_bucket_reso = min_bucket_reso
            self.max_bucket_reso = max_bucket_reso
            self.bucket_reso_steps = bucket_reso_steps
            self.bucket_no_upscale = bucket_no_upscale
        else:
            self.min_bucket_reso = None
            self.max_bucket_reso = None
            self.bucket_reso_steps = None
            self.bucket_no_upscale = False

        def read_caption(img_path, caption_extension):
            """Return the first line of the caption file for ``img_path``, or ``None``."""
            # Candidate caption files: "<name><ext>" and, when the name has at least
            # five "_"-separated tokens, the name with its last four tokens removed
            # (presumably the face-detection suffix -- TODO confirm).
            base_name = os.path.splitext(img_path)[0]
            base_name_face_det = base_name
            tokens = base_name.split("_")
            if len(tokens) >= 5:
                base_name_face_det = "_".join(tokens[:-4])
            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]

            caption = None
            for cap_path in cap_paths:
                if os.path.isfile(cap_path):
                    with open(cap_path, "rt", encoding="utf-8") as f:
                        try:
                            lines = f.readlines()
                        except UnicodeDecodeError as e:
                            print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
                            raise e
                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
                        caption = lines[0].strip()
                    # The first existing candidate wins.
                    break
            return caption

        def load_dreambooth_dir(subset: DreamBoothSubset):
            """Return ``(image_paths, captions)`` for one subset directory."""
            if not os.path.isdir(subset.image_dir):
                print(f"not directory: {subset.image_dir}")
                return [], []

            img_paths = glob_images(subset.image_dir, "*")
            print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

            # Use the caption file of each image when there is one; otherwise fall
            # back to the class tokens, or to an empty caption.
            captions = []
            missing_captions = []
            for img_path in img_paths:
                cap_for_img = read_caption(img_path, subset.caption_extension)
                if cap_for_img is None and subset.class_tokens is None:
                    print(
                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
                    )
                    captions.append("")
                    missing_captions.append(img_path)
                else:
                    if cap_for_img is None:
                        captions.append(subset.class_tokens)
                        missing_captions.append(img_path)
                    else:
                        captions.append(cap_for_img)

            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)

            if missing_captions:
                number_of_missing_captions = len(missing_captions)
                number_of_missing_captions_to_show = 5
                remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show

                print(
                    f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
                )
                for i, missing_caption in enumerate(missing_captions):
                    if i >= number_of_missing_captions_to_show:
                        print(missing_caption + f"... and {remaining_missing_captions} more")
                        break
                    print(missing_caption)
            return img_paths, captions

        print("prepare images.")
        num_train_images = 0
        num_reg_images = 0
        # Each regularization image is kept together with the subset it came from,
        # so that it can later be registered against that subset. Previously the
        # registration loop reused the leftover `subset` loop variable, which
        # attached every regularization image to the last subset iterated.
        reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
        for subset in subsets:
            if subset.num_repeats < 1:
                print(
                    f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
                )
                continue

            if subset in self.subsets:
                print(
                    f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
                )
                continue

            img_paths, captions = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
                continue

            if subset.is_reg:
                num_reg_images += subset.num_repeats * len(img_paths)
            else:
                num_train_images += subset.num_repeats * len(img_paths)

            for img_path, caption in zip(img_paths, captions):
                info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
                if subset.is_reg:
                    reg_infos.append((info, subset))
                else:
                    self.register_image(info, subset)

            subset.img_count = len(img_paths)
            self.subsets.append(subset)

        print(f"{num_train_images} train images with repeating.")
        self.num_train_images = num_train_images

        print(f"{num_reg_images} reg images.")
        if num_train_images < num_reg_images:
            print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

        if num_reg_images == 0:
            print("no regularization images / 正則化画像が見つかりませんでした")
        else:
            # Register regularization images on the first pass, then keep raising
            # their num_repeats one by one until the running total reaches the
            # number of training images.
            n = 0
            first_loop = True
            while n < num_train_images:
                for info, reg_subset in reg_infos:
                    if first_loop:
                        self.register_image(info, reg_subset)
                        n += info.num_repeats
                    else:
                        # Mutates the already registered info object.
                        info.num_repeats += 1
                        n += 1
                    if n >= num_train_images:
                        break
                first_loop = False

        self.num_reg_images = num_reg_images
|
|
|
|
|
class FineTuningDataset(BaseDataset):
    """Dataset driven by JSON metadata files (one per subset).

    Each metadata file maps an image key to a dict that may contain
    ``caption``, ``tags`` and ``train_resolution``. Pre-computed latent
    ``.npz`` files are used when they exist for every image and no subset
    enables ``color_aug`` or ``random_crop``.
    """

    def __init__(
        self,
        subsets: Sequence[FineTuningSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        debug_dataset,
    ) -> None:
        """Load the metadata of every subset and register its images.

        Raises:
            ValueError: If a subset's ``metadata_file`` does not exist.
            AssertionError: If an image key resolves to neither an image nor an
                npz file; if the metadata lacks ``train_resolution`` and
                ``resolution`` is ``None``; or if the metadata has
                ``train_resolution`` for all images and ``bucket_no_upscale``
                is set.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        self.batch_size = batch_size

        self.num_train_images = 0
        self.num_reg_images = 0

        for subset in subsets:
            if subset.num_repeats < 1:
                print(
                    f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
                )
                continue

            if subset in self.subsets:
                print(
                    f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
                )
                continue

            # Load the metadata file.
            if os.path.exists(subset.metadata_file):
                print(f"loading existing metadata: {subset.metadata_file}")
                with open(subset.metadata_file, "rt", encoding="utf-8") as f:
                    metadata = json.load(f)
            else:
                raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")

            if len(metadata) < 1:
                print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
                continue

            tags_list = []
            for image_key, img_md in metadata.items():
                # Resolve the image key to a path on disk.
                abs_path = None

                # Prefer an image file: the key itself, then a glob inside image_dir.
                # NOTE(review): image_key_to_npz_file tolerates subset.image_dir being
                # None, but this lookup passes it straight to glob_images / os.path.join
                # -- confirm that image_dir is always set when keys are not full paths.
                if os.path.exists(image_key):
                    abs_path = image_key
                else:
                    paths = glob_images(subset.image_dir, image_key)
                    if len(paths) > 0:
                        abs_path = paths[0]

                # Otherwise fall back to an npz file next to the key or in image_dir.
                if abs_path is None:
                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
                        abs_path = os.path.splitext(image_key)[0] + ".npz"
                    else:
                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
                        if os.path.exists(npz_path):
                            abs_path = npz_path

                assert abs_path is not None, f"no image / 画像がありません: {image_key}"

                # Caption is "caption, tags" when both exist, else whichever exists,
                # else the empty string. Tags are only recorded for the tag
                # frequency when they were appended to an existing caption.
                caption = img_md.get("caption")
                tags = img_md.get("tags")
                if caption is None:
                    caption = tags
                elif tags is not None and len(tags) > 0:
                    caption = caption + ", " + tags
                    tags_list.append(tags)

                if caption is None:
                    caption = ""

                image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
                image_info.image_size = img_md.get("train_resolution")

                if not subset.color_aug and not subset.random_crop:
                    # Record npz paths (either may be None) for later validation.
                    image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)

                self.register_image(image_info, subset)

            self.num_train_images += len(metadata) * subset.num_repeats

            self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
            subset.img_count = len(metadata)
            self.subsets.append(subset)

        # npz latents are only considered when no subset uses color_aug or
        # random_crop, and then only if every image has the files it needs.
        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
        if use_npz_latents:
            flip_aug_in_subset = False
            npz_any = False
            npz_all = True

            for image_info in self.image_data.values():
                subset = self.image_to_subset[image_info.image_key]

                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz

                if subset.flip_aug:
                    # With flip_aug the flipped npz is required as well.
                    has_npz = has_npz and image_info.latents_npz_flipped is not None
                    flip_aug_in_subset = True
                npz_all = npz_all and has_npz

                # Stop early once a mix of present and missing files is known.
                if npz_any and not npz_all:
                    break

            if not npz_any:
                use_npz_latents = False
                print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
            elif not npz_all:
                use_npz_latents = False
                print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
                if flip_aug_in_subset:
                    print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")

        # Collect the resolutions stored in the metadata; `sizes` becomes None
        # as soon as one image has no `train_resolution`.
        sizes = set()
        resos = set()
        for image_info in self.image_data.values():
            if image_info.image_size is None:
                sizes = None
                break
            sizes.add(image_info.image_size[0])
            sizes.add(image_info.image_size[1])
            resos.add(tuple(image_info.image_size))

        if sizes is None:
            # No bucket info in the metadata: npz files cannot be used and the
            # caller's bucketing settings apply.
            if use_npz_latents:
                use_npz_latents = False
                print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")

            assert (
                resolution is not None
            ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"

            self.enable_bucket = enable_bucket
            if self.enable_bucket:
                self.min_bucket_reso = min_bucket_reso
                self.max_bucket_reso = max_bucket_reso
                self.bucket_reso_steps = bucket_reso_steps
                self.bucket_no_upscale = bucket_no_upscale
        else:
            # Bucket info is present in the metadata: bucketing is forced on and
            # the resolutions from the metadata are used as predefined buckets.
            if not enable_bucket:
                print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
            print("using bucket info in metadata / メタデータ内のbucket情報を使います")
            self.enable_bucket = True

            assert (
                not bucket_no_upscale
            ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"

            self.bucket_manager = BucketManager(False, None, None, None, None)
            self.bucket_manager.set_predefined_resos(resos)

        # Clear the recorded npz paths when npz latents will not be used.
        if not use_npz_latents:
            for image_info in self.image_data.values():
                image_info.latents_npz = image_info.latents_npz_flipped = None

    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
        """Locate the latent npz files that belong to ``image_key``.

        The key is first treated as a path ("<stem>.npz" / "<stem>_flip.npz"),
        then as a name relative to ``subset.image_dir``
        ("<key>.npz" / "<key>_flip.npz").

        Returns:
            ``(npz_path, flipped_npz_path)``; each element is ``None`` when the
            file does not exist. The flipped path is always ``None`` when the
            normal one is missing.
        """
        base_name = os.path.splitext(image_key)[0]
        npz_file_norm = base_name + ".npz"

        if os.path.exists(npz_file_norm):
            # The key (minus extension) already points at an existing npz file.
            npz_file_flip = base_name + "_flip.npz"
            if not os.path.exists(npz_file_flip):
                npz_file_flip = None
            return npz_file_norm, npz_file_flip

        # Without an image_dir there is nowhere else to look.
        if subset.image_dir is None:
            return None, None

        # Treat the key as relative to image_dir.
        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")

        if not os.path.exists(npz_file_norm):
            npz_file_norm = None
            npz_file_flip = None
        elif not os.path.exists(npz_file_flip):
            npz_file_flip = None

        return npz_file_norm, npz_file_flip
|
|
|
|
|
class ControlNetDataset(BaseDataset):
    """Dataset that pairs each training image with a conditioning image.

    Image handling is delegated to an internal ``DreamBoothDataset``; this class
    adds a ``conditioning_images`` tensor to every batch. A conditioning image
    is looked up by the training image's file name inside the subset's
    ``conditioning_data_dir``.
    """

    def __init__(
        self,
        subsets: Sequence[ControlNetSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        debug_dataset,
    ) -> None:
        """Create the delegate dataset and verify the conditioning images.

        Raises:
            AssertionError: If a training image has no conditioning image, or a
                conditioning directory holds images without a matching
                training image.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        # Mirror every ControlNet subset as a DreamBooth subset
        # (second and third positional arguments are fixed to False / None).
        db_subsets = [
            DreamBoothSubset(
                cn_subset.image_dir,
                False,
                None,
                cn_subset.caption_extension,
                cn_subset.num_repeats,
                cn_subset.shuffle_caption,
                cn_subset.caption_separator,
                cn_subset.keep_tokens,
                cn_subset.keep_tokens_separator,
                cn_subset.color_aug,
                cn_subset.flip_aug,
                cn_subset.face_crop_aug_range,
                cn_subset.random_crop,
                cn_subset.caption_dropout_rate,
                cn_subset.caption_dropout_every_n_epochs,
                cn_subset.caption_tag_dropout_rate,
                cn_subset.caption_prefix,
                cn_subset.caption_suffix,
                cn_subset.token_warmup_min,
                cn_subset.token_warmup_step,
            )
            for cn_subset in subsets
        ]

        delegate = DreamBoothDataset(
            db_subsets,
            batch_size,
            tokenizer,
            max_token_length,
            resolution,
            enable_bucket,
            min_bucket_reso,
            max_bucket_reso,
            bucket_reso_steps,
            bucket_no_upscale,
            1.0,
            debug_dataset,
        )
        self.dreambooth_dataset_delegate = delegate

        # Expose the delegate's bookkeeping values on this object as well.
        self.image_data = delegate.image_data
        self.batch_size = batch_size
        self.num_train_images = delegate.num_train_images
        self.num_reg_images = delegate.num_reg_images

        # Check that every training image has a conditioning image.
        missing_imgs = []
        cond_imgs_with_img = set()
        for image_key, info in delegate.image_data.items():
            db_subset = delegate.image_to_subset[image_key]
            # First ControlNet subset sharing the image directory.
            cn_subset = next((s for s in subsets if s.image_dir == db_subset.image_dir), None)
            assert cn_subset is not None, "internal error: subset not found"

            if not os.path.isdir(cn_subset.conditioning_data_dir):
                print(f"not directory: {cn_subset.conditioning_data_dir}")
                continue

            img_basename = os.path.basename(info.absolute_path)
            ctrl_img_path = os.path.join(cn_subset.conditioning_data_dir, img_basename)
            if not os.path.exists(ctrl_img_path):
                missing_imgs.append(img_basename)

            info.cond_img_path = ctrl_img_path
            cond_imgs_with_img.add(ctrl_img_path)

        # Collect conditioning images that belong to no training image.
        extra_imgs = []
        for cn_subset in subsets:
            for cond_img_path in glob_images(cn_subset.conditioning_data_dir, "*"):
                if cond_img_path not in cond_imgs_with_img:
                    extra_imgs.append(cond_img_path)

        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"

        self.conditioning_image_transforms = IMAGE_TRANSFORMS

    def make_buckets(self):
        """Build buckets on the delegate and share its bucket state."""
        delegate = self.dreambooth_dataset_delegate
        delegate.make_buckets()
        self.bucket_manager = delegate.bucket_manager
        self.buckets_indices = delegate.buckets_indices

    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
        """Forward latent caching to the delegate and return its result."""
        delegate = self.dreambooth_dataset_delegate
        return delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

    def __len__(self):
        """Return the delegate's length."""
        return self.dreambooth_dataset_delegate.__len__()

    def __getitem__(self, index):
        """Return the delegate's batch extended with ``conditioning_images``."""
        delegate = self.dreambooth_dataset_delegate
        example = delegate[index]

        batch_ref = delegate.buckets_indices[index]
        bucket = delegate.bucket_manager.buckets[batch_ref.bucket_index]
        per_batch = batch_ref.bucket_batch_size
        first = batch_ref.batch_index * per_batch

        cond_tensors = []
        for pos, image_key in enumerate(bucket[first : first + per_batch]):
            image_info = delegate.image_data[image_key]

            target_hw = example["target_sizes_hw"][pos]
            original_hw = example["original_sizes_hw"][pos]
            crop_top_left = example["crop_top_lefts"][pos]  # read but not used below
            is_flipped = example["flippeds"][pos]
            cond_img = load_image(image_info.cond_img_path)

            if delegate.enable_bucket:
                assert (
                    cond_img.shape[0] == original_hw[0] and cond_img.shape[1] == original_hw[1]
                ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
                cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)

                # Only a centre crop is performed here.
                h, w = target_hw
                top = (cond_img.shape[0] - h) // 2
                left = (cond_img.shape[1] - w) // 2
                cond_img = cond_img[top : top + h, left : left + w]
            elif cond_img.shape[0] != target_hw[0] or cond_img.shape[1] != target_hw[1]:
                # Without bucketing, resize straight to the target size.
                cond_img = cv2.resize(
                    cond_img, (int(target_hw[1]), int(target_hw[0])), interpolation=cv2.INTER_LANCZOS4
                )

            if is_flipped:
                # Copy so the result does not keep a negative stride.
                cond_img = cond_img[:, ::-1, :].copy()

            cond_tensors.append(self.conditioning_image_transforms(cond_img))

        stacked = torch.stack(cond_tensors)
        example["conditioning_images"] = stacked.to(memory_format=torch.contiguous_format).float()

        return example
|
|
|
|
|
|
|
class DatasetGroup(torch.utils.data.ConcatDataset):
    """Concatenation of datasets that forwards control calls to every member.

    ``image_data`` is the union of the members' ``image_data`` dicts (later
    members overwrite equal keys), and the image counters are the sums of the
    members' counters.
    """

    def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
        self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]

        super().__init__(datasets)

        merged_image_data = {}
        train_total = 0
        reg_total = 0
        for member in datasets:
            merged_image_data.update(member.image_data)
            train_total += member.num_train_images
            reg_total += member.num_reg_images

        self.image_data = merged_image_data
        self.num_train_images = train_total
        self.num_reg_images = reg_total

    def _for_each(self, method_name, *args, **kwargs):
        """Call ``method_name`` with the given arguments on every member, in order."""
        for member in self.datasets:
            getattr(member, method_name)(*args, **kwargs)

    def add_replacement(self, str_from, str_to):
        """Forward ``add_replacement`` to every member."""
        self._for_each("add_replacement", str_from, str_to)

    def enable_XTI(self, *args, **kwargs):
        """Forward ``enable_XTI`` to every member."""
        self._for_each("enable_XTI", *args, **kwargs)

    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
        """Cache latents member by member, printing a header before each one."""
        for position, member in enumerate(self.datasets):
            print(f"[Dataset {position}]")
            member.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

    def cache_text_encoder_outputs(
        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
    ):
        """Cache text encoder outputs member by member, printing a header before each one."""
        for position, member in enumerate(self.datasets):
            print(f"[Dataset {position}]")
            member.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)

    def set_caching_mode(self, caching_mode):
        """Forward ``set_caching_mode`` to every member."""
        self._for_each("set_caching_mode", caching_mode)

    def verify_bucket_reso_steps(self, min_steps: int):
        """Forward ``verify_bucket_reso_steps`` to every member."""
        self._for_each("verify_bucket_reso_steps", min_steps)

    def is_latent_cacheable(self) -> bool:
        """Return True only if every member reports cacheable latents."""
        # Every member is queried (no short-circuit) before the verdict is formed.
        answers = [member.is_latent_cacheable() for member in self.datasets]
        return all(answers)

    def is_text_encoder_output_cacheable(self) -> bool:
        """Return True only if every member reports cacheable text encoder outputs."""
        answers = [member.is_text_encoder_output_cacheable() for member in self.datasets]
        return all(answers)

    def set_current_epoch(self, epoch):
        """Forward ``set_current_epoch`` to every member."""
        self._for_each("set_current_epoch", epoch)

    def set_current_step(self, step):
        """Forward ``set_current_step`` to every member."""
        self._for_each("set_current_step", step)

    def set_max_train_steps(self, max_train_steps):
        """Forward ``set_max_train_steps`` to every member."""
        self._for_each("set_max_train_steps", max_train_steps)

    def disable_token_padding(self):
        """Forward ``disable_token_padding`` to every member."""
        self._for_each("disable_token_padding")
|
|
|
|
|
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
    """Check whether a cached latents file exists and matches the expected size.

    Args:
        reso: Two-item resolution. Index 0 is compared (after ``// 8``) with axis 2
            of the stored latents and index 1 with axis 1, i.e. it is read as
            (width, height) against latents whose axes 1 and 2 are presumably
            height and width -- TODO confirm.
        npz_path: Path of the ``.npz`` archive to inspect.
        flip_aug: When true, a matching ``latents_flipped`` entry is required too.

    Returns:
        ``True`` if the file exists, contains ``latents``, ``original_size`` and
        ``crop_ltrb`` (plus ``latents_flipped`` when ``flip_aug`` is set) and the
        latent arrays have the expected size; ``False`` otherwise.
    """
    expected_latents_size = (reso[1] // 8, reso[0] // 8)

    if not os.path.exists(npz_path):
        return False

    # The archive is opened as a context manager so its file handle is closed
    # on every return path (previously it was left open).
    with np.load(npz_path) as npz:
        if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:
            return False
        if npz["latents"].shape[1:3] != expected_latents_size:
            return False

        if flip_aug:
            if "latents_flipped" not in npz:
                return False
            if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                return False

    return True
|
|
|
|
|
|
|
def load_latents_from_disk(
    npz_path,
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray]]:
    """Read cached latents and their size information from an ``.npz`` archive.

    Args:
        npz_path: Path of the archive to read.

    Returns:
        ``(latents, original_size, crop_ltrb, flipped_latents)``. ``latents`` and
        ``flipped_latents`` are NumPy arrays exactly as stored (the latter is
        ``None`` when the archive has no ``latents_flipped`` entry);
        ``original_size`` and ``crop_ltrb`` are converted to Python lists.

    Raises:
        ValueError: If the archive has no ``latents`` entry.
    """
    # Entries are read into memory inside the `with` block, so the archive can
    # be closed before returning (previously the handle was left open).
    with np.load(npz_path) as npz:
        if "latents" not in npz:
            raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

        latents = npz["latents"]
        original_size = npz["original_size"].tolist()
        crop_ltrb = npz["crop_ltrb"].tolist()
        flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
    return latents, original_size, crop_ltrb, flipped_latents
|
|
|
|
|
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
    """Write latents and their size information to an ``.npz`` archive.

    The archive holds ``latents``, ``original_size`` and ``crop_ltrb``, plus
    ``latents_flipped`` when ``flipped_latents_tensor`` is given. Tensors are
    converted with ``.float().cpu().numpy()`` before being stored.
    """
    flipped_array = None
    if flipped_latents_tensor is not None:
        flipped_array = flipped_latents_tensor.float().cpu().numpy()

    entries = {
        "latents": latents_tensor.float().cpu().numpy(),
        "original_size": np.array(original_size),
        "crop_ltrb": np.array(crop_ltrb),
    }
    if flipped_array is not None:
        entries["latents_flipped"] = flipped_array

    np.savez(npz_path, **entries)
|
|
|
|
|
def debug_dataset(train_dataset, show_input_ids=False):
    """Interactively step through a dataset and print information about each sample.

    For every epoch the dataset indices are shuffled, each example is fetched and its
    per-image metadata is printed. On Windows (``os.name == "nt"``) the images are
    additionally shown with OpenCV and a key press decides how to continue.

    Args:
        train_dataset: Dataset supporting ``len()``, indexing, ``set_current_epoch``,
            ``set_current_step`` and an ``image_data`` mapping keyed by image key.
        show_input_ids: When true, also print the token ids of each caption.

    NOTE(review): the nesting below was reconstructed from an indentation-less copy
    of this file; confirm it against the original layout.
    """
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")

    epoch = 1
    while True:
        print(f"\nepoch: {epoch}")

        # Step numbering continues across epochs.
        steps = (epoch - 1) * len(train_dataset) + 1
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)

        # Last key code; stays 0 until cv2.waitKey() assigns it (Windows only).
        k = 0
        for i, idx in enumerate(indices):
            train_dataset.set_current_epoch(epoch)
            train_dataset.set_current_step(steps)
            print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")

            example = train_dataset[idx]
            if example["latents"] is not None:
                print(f"sample has latents from npz file: {example['latents'].size()}")
            # Walk over the individual images of the batch.
            for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
                zip(
                    example["image_keys"],
                    example["captions"],
                    example["loss_weights"],
                    example["input_ids"],
                    example["original_sizes_hw"],
                    example["crop_top_lefts"],
                    example["target_sizes_hw"],
                    example["flippeds"],
                )
            ):
                print(
                    f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
                )

                if show_input_ids:
                    print(f"input ids: {iid}")
                    if "input_ids2" in example:
                        print(f"input ids2: {example['input_ids2'][j]}")
                if example["images"] is not None:
                    im = example["images"][j]
                    print(f"image size: {im.size()}")
                    # presumably maps values from [-1, 1] to [0, 255] — TODO confirm the input range
                    im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
                    # Move the first axis to the end (assumes channel-first input — TODO confirm).
                    im = np.transpose(im, (1, 2, 0))
                    # Reverse the last axis (presumably RGB -> BGR for OpenCV display).
                    im = im[:, :, ::-1]

                    if "conditioning_images" in example:
                        cond_img = example["conditioning_images"][j]
                        print(f"conditioning image size: {cond_img.size()}")
                        cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8)
                        cond_img = np.transpose(cond_img, (1, 2, 0))
                        cond_img = cond_img[:, :, ::-1]
                        if os.name == "nt":
                            cv2.imshow("cond_img", cond_img)

                    # Windows are only opened on Windows; elsewhere k keeps its previous value.
                    if os.name == "nt":
                        cv2.imshow("img", im)
                        k = cv2.waitKey()
                        cv2.destroyAllWindows()
                    # 27 is the Escape key code; "s" and "e" also stop showing this batch.
                    if k == 27 or k == ord("s") or k == ord("e"):
                        break
            steps += 1

            # "e": leave this epoch and continue with the next one.
            if k == ord("e"):
                break
            # Escape, or more than 9 samples inspected without images: stop completely.
            if k == 27 or (example["images"] is None and i >= 8):
                k = 27
                break
        if k == 27:
            break

        epoch += 1
|
|
|
|
|
def glob_images(directory, base="*"):
    """Return the sorted, de-duplicated image paths found in ``directory``.

    One glob is run per entry of ``IMAGE_EXTENSIONS``. With the default ``base`` of
    ``"*"`` only the directory part is escaped, so the wildcard stays active; for any
    other ``base`` the complete path is escaped and therefore matched literally.

    Args:
        directory: Directory to search.
        base: File name stem, or ``"*"`` for every file.

    Returns:
        Sorted list of matching paths.
    """
    found = set()
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))
    return sorted(found)
|
|
|
|
|
def glob_images_pathlib(dir_path, recursive):
    """Return the sorted, de-duplicated image paths below ``dir_path``.

    Args:
        dir_path: Object offering ``glob`` and ``rglob`` (e.g. a ``pathlib.Path``).
        recursive: Use ``rglob`` (descend into sub-directories) when true, ``glob``
            otherwise.

    Returns:
        Sorted list with one entry per distinct path matching ``"*" + ext`` for any
        ``ext`` in ``IMAGE_EXTENSIONS``.
    """
    search = dir_path.rglob if recursive else dir_path.glob
    unique_paths = set()
    for ext in IMAGE_EXTENSIONS:
        unique_paths.update(search("*" + ext))
    return sorted(unique_paths)
|
|
|
|
|
class MinimalDataset(BaseDataset):
    """Skeleton dataset for user-provided dataset classes.

    The instance lists itself as its only dataset and its only subset and fills the
    remaining bookkeeping attributes with fixed starting values. ``__len__`` and
    ``__getitem__`` raise ``NotImplementedError`` and have to be supplied by a subclass.
    """

    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        """Initialise the base dataset and set the bookkeeping attributes."""
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        # Starting values (presumably adjusted by subclasses -- TODO confirm).
        self.num_train_images, self.num_reg_images = 0, 0
        self.datasets = [self]
        self.batch_size = 1

        # The dataset doubles as its own single subset.
        self.subsets = [self]
        self.num_repeats, self.img_count = 1, 1
        self.bucket_info = {}
        self.is_reg = False
        self.image_dir = "dummy"

    def verify_bucket_reso_steps(self, min_steps: int):
        """Do nothing; this dataset performs no bucket resolution check."""

    def is_latent_cacheable(self) -> bool:
        """Report that latents are not cacheable for this dataset (always ``False``)."""
        return False

    def __len__(self):
        """Must be implemented by the subclass."""
        raise NotImplementedError

    def set_current_epoch(self, epoch):
        """Remember ``epoch`` in ``self.current_epoch``."""
        self.current_epoch = epoch

    def __getitem__(self, idx):
        r"""
        Must be implemented by the subclass.

        The subclass may provide ``image_data`` (a dict of ImageInfo objects), which
        is used by debug_dataset.

        The returned example is expected to be built like this:

            for i in range(batch_size):
                image_key = ...  # any hashable value
                image_keys.append(image_key)

                image = ...  # PIL Image
                img_tensor = self.image_transforms(image)
                images.append(img_tensor)

                caption = ...  # str
                input_ids = self.get_input_ids(caption)
                input_ids_list.append(input_ids)

                captions.append(caption)

            images = torch.stack(images, dim=0)
            input_ids_list = torch.stack(input_ids_list, dim=0)
            example = {
                "images": images,
                "input_ids": input_ids_list,
                "captions": captions,  # for debug_dataset
                "latents": None,
                "image_keys": image_keys,  # for debug_dataset
                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
            }
            return example
        """
        raise NotImplementedError
|
|
|
|
|
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
    """Instantiate the dataset class named by ``args.dataset_class``.

    ``args.dataset_class`` is a dotted path: everything before the last dot is the
    module to import, the part after it is the attribute looked up on that module.

    Args:
        args: Namespace providing ``dataset_class``, ``max_token_length``,
            ``resolution`` and ``debug_dataset``.
        tokenizer: Passed as the first constructor argument.

    Returns:
        The object created by calling the class with
        ``(tokenizer, max_token_length, resolution, debug_dataset)``.
    """
    module_path, _, class_name = args.dataset_class.rpartition(".")
    dataset_cls = getattr(importlib.import_module(module_path), class_name)
    train_dataset_group: MinimalDataset = dataset_cls(
        tokenizer, args.max_token_length, args.resolution, args.debug_dataset
    )
    return train_dataset_group
|
|
|
|
|
def load_image(image_path):
    """Load an image file and return it as an RGB ``uint8`` numpy array.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` created from the image, converted to RGB
        first when its mode is anything else.
    """
    # The context manager closes the underlying file deterministically; without it
    # the handle can stay open (e.g. for multi-frame formats) until garbage collection.
    with Image.open(image_path) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        # The array is materialised inside the block, while the file is still open.
        return np.array(rgb, np.uint8)
|
|
|
|
|
|
|
def trim_and_resize_if_required(
    random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    """Resize ``image`` to ``resized_size`` and trim it to ``reso``.

    Args:
        random_crop: Pick the trim offset at random instead of centring it.
        image: Image to process. It is accessed through ``.shape`` and slicing, i.e.
            handled as an array with height on axis 0 and width on axis 1.
        reso: Target ``(width, height)`` after trimming.
        resized_size: ``(width, height)`` the image is resized to before trimming.

    Returns:
        ``(image, original_size, crop_ltrb)`` where ``original_size`` is the
        ``(width, height)`` of the input and ``crop_ltrb`` comes from
        ``BucketManager.get_crop_ltrb(reso, original_size)``; it does not depend on
        the trim offsets chosen here.
    """
    src_height, src_width = image.shape[0:2]
    original_size = (src_width, src_height)

    already_sized = src_width == resized_size[0] and src_height == resized_size[1]
    if not already_sized:
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

    cur_height, cur_width = image.shape[0:2]

    # Width is trimmed first, then height; with random_crop this fixes the order in
    # which random numbers are drawn.
    if cur_width > reso[0]:
        surplus = cur_width - reso[0]
        left = random.randint(0, surplus) if random_crop else surplus // 2
        image = image[:, left : left + reso[0]]
    if cur_height > reso[1]:
        surplus = cur_height - reso[1]
        top = random.randint(0, surplus) if random_crop else surplus // 2
        image = image[top : top + reso[1]]

    crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
    return image, original_size, crop_ltrb
|
|
|
|
|
def cache_batch_latents(
    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
    r"""
    Encode one batch of images with the VAE and cache the resulting latents.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image
    if cache_to_disk is True, the latents are written to info.latents_npz
        (flipped latents are saved as well if flip_aug is True)
    if cache_to_disk is False, info.latents is set
        (info.latents_flipped is set as well if flip_aug is True)
    latents_original_size and latents_crop_ltrb are set in both cases

    Raises:
        RuntimeError: If the latents (or flipped latents) of an image contain NaN.
            The message names the image whose latents are affected, and nothing is
            stored for the batch in that case.
    """
    images = []
    for info in image_infos:
        # Use the already loaded image when there is one, otherwise read it from disk.
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)

        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        image = IMAGE_TRANSFORMS(image)
        images.append(image)

        info.latents_original_size = original_size
        info.latents_crop_ltrb = crop_ltrb

    img_tensors = torch.stack(images, dim=0)
    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

    with torch.no_grad():
        latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")

    if flip_aug:
        # Flip along the last axis and encode again.
        img_tensors = torch.flip(img_tensors, dims=[3])
        with torch.no_grad():
            flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
    else:
        flipped_latents = [None] * len(latents)

    # Validate every image first, checking its own latent (not the whole batch), so
    # the error points at the image that actually produced NaN and no partial cache
    # is written for a bad batch.
    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
|
|
|
def cache_batch_text_encoder_outputs(
    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
    """Run both text encoders on one batch and cache their outputs per image.

    Args:
        image_infos: One info object per batch element; receives the outputs or
            provides ``text_encoder_outputs_npz`` as the target file.
        tokenizers: Sequence whose first two entries are passed to
            ``get_hidden_states_sdxl``.
        text_encoders: Sequence whose first two entries are the text encoders.
        max_token_length: Forwarded to ``get_hidden_states_sdxl``.
        cache_to_disk: Write each result to disk instead of storing it on the info.
        input_ids1: Token ids for the first encoder (moved to its device).
        input_ids2: Token ids for the second encoder (moved to its device).
        dtype: Forwarded to ``get_hidden_states_sdxl``.
    """
    ids_for_encoder1 = input_ids1.to(text_encoders[0].device)
    ids_for_encoder2 = input_ids2.to(text_encoders[1].device)

    with torch.no_grad():
        batch_hidden1, batch_hidden2, batch_pool2 = get_hidden_states_sdxl(
            max_token_length,
            ids_for_encoder1,
            ids_for_encoder2,
            tokenizers[0],
            tokenizers[1],
            text_encoders[0],
            text_encoders[1],
            dtype,
        )

        # Detach the results and move them to the CPU before they are stored.
        batch_hidden1 = batch_hidden1.detach().to("cpu")
        batch_hidden2 = batch_hidden2.detach().to("cpu")
        batch_pool2 = batch_pool2.detach().to("cpu")

    for info, hidden1, hidden2, pooled in zip(image_infos, batch_hidden1, batch_hidden2, batch_pool2):
        if not cache_to_disk:
            info.text_encoder_outputs1 = hidden1
            info.text_encoder_outputs2 = hidden2
            info.text_encoder_pool2 = pooled
            continue
        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden1, hidden2, pooled)
|
|
|
|
|
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
    """Store text encoder outputs in an ``.npz`` file.

    Each tensor is moved to the CPU, converted to float and saved as a numpy array
    under the keys ``hidden_state1``, ``hidden_state2`` and ``pool2``.

    Args:
        npz_path: Destination path handed to ``np.savez``.
        hidden_state1: Tensor stored as ``hidden_state1``.
        hidden_state2: Tensor stored as ``hidden_state2``.
        pool2: Tensor stored as ``pool2``.
    """

    def _to_float_array(tensor):
        return tensor.cpu().float().numpy()

    arrays = {
        "hidden_state1": _to_float_array(hidden_state1),
        "hidden_state2": _to_float_array(hidden_state2),
        "pool2": _to_float_array(pool2),
    }
    np.savez(npz_path, **arrays)
|
|
|
|
|
def load_text_encoder_outputs_from_disk(npz_path):
    """Read text encoder outputs back from an ``.npz`` file.

    Args:
        npz_path: Path of the file to read.

    Returns:
        ``(hidden_state1, hidden_state2, pool2)`` as torch tensors. ``hidden_state1``
        is required; the other two are ``None`` when the file has no such entry.
    """
    with np.load(npz_path) as archive:

        def _tensor_or_none(key):
            return torch.from_numpy(archive[key]) if key in archive else None

        first_hidden = torch.from_numpy(archive["hidden_state1"])
        second_hidden = _tensor_or_none("hidden_state2")
        pooled = _tensor_or_none("pool2")
    return first_hidden, second_hidden, pooled
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Module replacement for speed-up
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Small constant (1e-6). Not referenced by the definitions that follow in this part of
# the file; presumably a numerical-stability epsilon for attention code — TODO confirm.
EPSILON = 1e-6
|
|
|
|
|
|
|
|
|
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
def default(val, d):
    """Return ``val`` when it is set (see ``exists``), otherwise the fallback ``d``."""
    if exists(val):
        return val
    return d
|
|
|
|
|
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui

    Hashes the 0x10000 bytes that start at offset 0x100000 of the file with SHA-256
    and returns the first 8 hex digits.

    Returns:
        The short hash, ``"NOFILE"`` if the file does not exist, or ``"IsADirectory"``
        if opening it raises ``IsADirectoryError`` or ``PermissionError``.
    """
    try:
        with open(filename, "rb") as file:
            file.seek(0x100000)
            digest = hashlib.sha256(file.read(0x10000)).hexdigest()
        return digest[0:8]
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        return "IsADirectory"
|
|
|
|
|
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui

    Computes the SHA-256 of the complete file, reading it in 1 MiB blocks.

    Returns:
        The hex digest, ``"NOFILE"`` if the file does not exist, or ``"IsADirectory"``
        if opening it raises ``IsADirectoryError`` or ``PermissionError``.
    """
    block_size = 1024 * 1024
    digest = hashlib.sha256()
    try:
        with open(filename, "rb") as stream:
            while True:
                block = stream.read(block_size)
                if block == b"":
                    break
                digest.update(block)
        return digest.hexdigest()
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        return "IsADirectory"
|
|
|
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Only metadata entries whose key starts with ``"ss_"`` take part: the tensors are
    serialised in memory together with that filtered metadata and both hash variants
    are computed from the resulting bytes.

    Returns:
        ``(model_hash, legacy_hash)`` as produced by ``addnet_hash_safetensors`` and
        ``addnet_hash_legacy``.
    """
    hashed_metadata = {key: value for key, value in metadata.items() if key.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, hashed_metadata)
    buffer = BytesIO(serialized)

    new_hash = addnet_hash_safetensors(buffer)
    old_hash = addnet_hash_legacy(buffer)
    return new_hash, old_hash
|
|
|
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files

    Hashes the 0x10000 bytes starting at offset 0x100000 of the seekable binary
    stream ``b`` with SHA-256 and returns the first 8 hex digits.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[0:8]
|
|
|
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files

    The first 8 bytes of the seekable binary stream ``b`` are read as a little-endian
    integer ``n``. Everything after offset ``n + 8`` is hashed with SHA-256 in 1 MiB
    blocks, and the full hex digest is returned.
    """
    block_size = 1024 * 1024

    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")

    # Skip the 8-byte length field plus the header it describes.
    b.seek(header_length + 8)

    digest = hashlib.sha256()
    while True:
        block = b.read(block_size)
        if block == b"":
            break
        digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
def get_git_revision_hash() -> str:
    """Return the git commit hash of the checkout that contains this file.

    Runs ``git rev-parse HEAD`` in the directory of this file.

    Returns:
        The stripped command output, or ``"(unknown)"`` when the hash cannot be
        determined (git missing, not a repository, undecodable output, ...).
    """
    try:
        # abspath guards against a relative __file__ without a directory part, which
        # would otherwise produce an empty (invalid) cwd.
        repo_dir = os.path.dirname(os.path.abspath(__file__))
        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=repo_dir).decode("ascii").strip()
    except Exception:
        # Best effort only. Catching Exception instead of using a bare except lets
        # KeyboardInterrupt and SystemExit propagate.
        return "(unknown)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
    """Switch the attention implementation of the U-Net.

    The first truthy flag wins, in the order ``mem_eff_attn``, ``xformers``, ``sdpa``;
    if none is set the U-Net is left untouched.

    Args:
        unet: U-Net whose attention setting is changed.
        mem_eff_attn: Calls ``unet.set_use_memory_efficient_attention(False, True)``.
        xformers: Calls ``unet.set_use_memory_efficient_attention(True, False)`` after
            verifying that ``xformers.ops`` can be imported.
        sdpa: Calls ``unet.set_use_sdpa(True)``.

    Raises:
        ImportError: If ``xformers`` is requested but not importable.
    """
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")
        unet.set_use_memory_efficient_attention(False, True)
        return

    if xformers:
        print("Enable xformers for U-Net")
        try:
            import xformers.ops
        except ImportError:
            raise ImportError("No xformers / xformersがインストールされていないようです")

        unet.set_use_memory_efficient_attention(True, False)
        return

    if sdpa:
        print("Enable SDPA for U-Net")
        unet.set_use_sdpa(True)
|
|
|
|
|
""" |
|
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): |
|
# vae is not used currently, but it is here for future use |
|
if mem_eff_attn: |
|
replace_vae_attn_to_memory_efficient() |
|
elif xformers: |
|
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ |
|
print("Use Diffusers xformers for VAE") |
|
vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) |
|
vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) |
|
|
|
|
|
def replace_vae_attn_to_memory_efficient(): |
|
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") |
|
flash_func = FlashAttentionFunction |
|
|
|
def forward_flash_attn(self, hidden_states): |
|
print("forward_flash_attn") |
|
q_bucket_size = 512 |
|
k_bucket_size = 1024 |
|
|
|
residual = hidden_states |
|
batch, channel, height, width = hidden_states.shape |
|
|
|
# norm |
|
hidden_states = self.group_norm(hidden_states) |
|
|
|
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) |
|
|
|
# proj to q, k, v |
|
query_proj = self.query(hidden_states) |
|
key_proj = self.key(hidden_states) |
|
value_proj = self.value(hidden_states) |
|
|
|
query_proj, key_proj, value_proj = map( |
|
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) |
|
) |
|
|
|
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) |
|
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
|
|
# compute next hidden_states |
|
hidden_states = self.proj_attn(hidden_states) |
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) |
|
|
|
# res connect and rescale |
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor |
|
return hidden_states |
|
|
|
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
    r"""
    Read the metadata of a .safetensors file.

    This method locks the file. see https://github.com/huggingface/safetensors/issues/164
    If the file isn't .safetensors or doesn't have metadata, return empty dict.
    """
    _, extension = os.path.splitext(safetensors_file)
    if extension != ".safetensors":
        return {}

    with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
        metadata = f.metadata()
    return {} if metadata is None else metadata
|
|
|
|
|
|
|
# Names of the network metadata entries; every key carries the "ss_" prefix.
SS_METADATA_KEY_V2 = "ss_v2"
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"

# All keys that build_minimum_network_metadata can write (the optional ones only
# when a value is given).
SS_METADATA_MINIMUM_KEYS = [
    SS_METADATA_KEY_V2,
    SS_METADATA_KEY_BASE_MODEL_VERSION,
    SS_METADATA_KEY_NETWORK_MODULE,
    SS_METADATA_KEY_NETWORK_DIM,
    SS_METADATA_KEY_NETWORK_ALPHA,
    SS_METADATA_KEY_NETWORK_ARGS,
]
|
|
|
|
|
def build_minimum_network_metadata(
    v2: Optional[bool],
    base_model: Optional[str],
    network_module: str,
    network_dim: str,
    network_alpha: str,
    network_args: Optional[dict],
):
    """Build the minimum metadata dict describing a network.

    The module, dim and alpha entries are always present. ``v2`` and ``base_model``
    are added only when they are not ``None``; ``network_args`` is added as a JSON
    string when it is not ``None``.

    Returns:
        Dict keyed by the ``SS_METADATA_KEY_*`` constants.
    """
    metadata = {
        SS_METADATA_KEY_NETWORK_MODULE: network_module,
        SS_METADATA_KEY_NETWORK_DIM: network_dim,
        SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
    }

    # Optional entries that are stored as given.
    optional_entries = (
        (SS_METADATA_KEY_V2, v2),
        (SS_METADATA_KEY_BASE_MODEL_VERSION, base_model),
    )
    for key, value in optional_entries:
        if value is not None:
            metadata[key] = value

    if network_args is not None:
        metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
    return metadata
|
|
|
|
|
def get_sai_model_spec(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: Optional[bool] = None,
):
    """Collect the arguments for ``sai_model_spec.build_metadata`` and return its result.

    The title falls back to ``args.output_name`` when ``args.metadata_title`` is not
    given. A timestep range is passed only if ``args.min_timestep`` or
    ``args.max_timestep`` is set; the missing bound defaults to 0 or 1000.

    Args:
        state_dict: Forwarded to ``build_metadata``.
        args: Parsed arguments providing the model and metadata options.
        sdxl: Forwarded to ``build_metadata``.
        lora: Forwarded to ``build_metadata``.
        textual_inversion: Forwarded to ``build_metadata``.
        is_stable_diffusion_ckpt: Forwarded to ``build_metadata``.

    Returns:
        Whatever ``sai_model_spec.build_metadata`` returns.
    """
    created_at = time.time()

    is_v2 = args.v2
    uses_v_param = args.v_parameterization
    resolution = args.resolution

    model_title = args.output_name if args.metadata_title is None else args.metadata_title

    if args.min_timestep is None and args.max_timestep is None:
        timestep_range = None
    else:
        lower = 0 if args.min_timestep is None else args.min_timestep
        upper = 1000 if args.max_timestep is None else args.max_timestep
        timestep_range = (lower, upper)

    return sai_model_spec.build_metadata(
        state_dict,
        is_v2,
        uses_v_param,
        sdxl,
        lora,
        textual_inversion,
        created_at,
        title=model_title,
        reso=resolution,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timestep_range,
        clip_skip=args.clip_skip,
    )
|
|
|
|
|
def add_sd_models_arguments(parser: argparse.ArgumentParser):
    """Register the command line options that select the Stable Diffusion model.

    Adds ``--v2``, ``--v_parameterization``, ``--pretrained_model_name_or_path`` and
    ``--tokenizer_cache_dir`` to ``parser``.
    """
    add = parser.add_argument

    add("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
    add(
        "--v_parameterization",
        action="store_true",
        help="enable v-parameterization training / v-parameterization学習を有効にする",
    )
    add(
        "--pretrained_model_name_or_path",
        default=None,
        type=str,
        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
    )
    add(
        "--tokenizer_cache_dir",
        default=None,
        type=str,
        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
    )
|
|
|
|
|
def add_optimizer_arguments(parser: argparse.ArgumentParser):
    """Register the optimizer and learning-rate scheduler options on ``parser``.

    Adds the optimizer selection flags, learning rate, gradient clipping norm, free-form
    optimizer/scheduler arguments and the scheduler settings (type, warmup steps,
    number of cycles, polynomial power).
    """
    parser.add_argument(
        "--optimizer_type",
        type=str,
        default="",
        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
    )

    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)",
    )
    parser.add_argument(
        "--use_lion_optimizer",
        action="store_true",
        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)",
    )

    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない"
    )

    parser.add_argument(
        "--optimizer_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
    )

    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
    # The Japanese example previously read "T_max100"; it now shows the same
    # key=value form as the English example.
    parser.add_argument(
        "--lr_scheduler_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
    )

    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
    )
    parser.add_argument(
        "--lr_scheduler_num_cycles",
        type=int,
        default=1,
        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
    )
    parser.add_argument(
        "--lr_scheduler_power",
        type=float,
        default=1,
        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
    )
|
|
|
|
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    """Register the general training options on ``parser``.

    Covers output and Hugging Face upload settings, checkpoint/state saving, batch and
    token sizes, attention backends, training length, data loading, precision, DDP,
    logging, noise options, timestep limits, sample generation, config files and model
    metadata.

    Args:
        parser: Parser that receives the options.
        support_dreambooth: When true, ``--prior_loss_weight`` is added as well.
    """
    add = parser.add_argument

    # --- output location and Hugging Face upload ---
    add("--output_dir", default=None, type=str, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
    add("--output_name", default=None, type=str, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
    add(
        "--huggingface_repo_id",
        default=None,
        type=str,
        help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
    )
    add(
        "--huggingface_repo_type",
        default=None,
        type=str,
        help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
    )
    add(
        "--huggingface_path_in_repo",
        default=None,
        type=str,
        help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
    )
    add("--huggingface_token", default=None, type=str, help="huggingface token / huggingfaceのトークン")
    add(
        "--huggingface_repo_visibility",
        default=None,
        type=str,
        help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
    )
    add(
        "--save_state_to_huggingface",
        action="store_true",
        help="save state to huggingface / huggingfaceにstateを保存する",
    )
    add(
        "--resume_from_huggingface",
        action="store_true",
        help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
    )
    add(
        "--async_upload",
        action="store_true",
        help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
    )

    # --- checkpoint and state saving ---
    add(
        "--save_precision",
        default=None,
        type=str,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving / 保存時に精度を変更して保存する",
    )
    add(
        "--save_every_n_epochs",
        default=None,
        type=int,
        help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
    )
    add(
        "--save_every_n_steps",
        default=None,
        type=int,
        help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
    )
    add(
        "--save_n_epoch_ratio",
        default=None,
        type=int,
        help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)",
    )
    add(
        "--save_last_n_epochs",
        default=None,
        type=int,
        help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
    )
    add(
        "--save_last_n_epochs_state",
        default=None,
        type=int,
        help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
    )
    add(
        "--save_last_n_steps",
        default=None,
        type=int,
        help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
    )
    add(
        "--save_last_n_steps_state",
        default=None,
        type=int,
        help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
    )
    add(
        "--save_state",
        action="store_true",
        help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
    )
    add("--resume", default=None, type=str, help="saved state to resume training / 学習再開するモデルのstate")

    # --- batch size, token length, attention backend, VAE ---
    add("--train_batch_size", default=1, type=int, help="batch size for training / 学習時のバッチサイズ")
    add(
        "--max_token_length",
        default=None,
        type=int,
        choices=[None, 150, 225],
        help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)",
    )
    add(
        "--mem_eff_attn",
        action="store_true",
        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
    )
    add("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
    add(
        "--sdpa",
        action="store_true",
        help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
    )
    add(
        "--vae",
        default=None,
        type=str,
        help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ",
    )

    # --- training length, data loading, precision, DDP ---
    add("--max_train_steps", default=1600, type=int, help="training steps / 学習ステップ数")
    add(
        "--max_train_epochs",
        default=None,
        type=int,
        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
    )
    add(
        "--max_data_loader_n_workers",
        default=8,
        type=int,
        help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
    )
    add(
        "--persistent_data_loader_workers",
        action="store_true",
        help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
    )
    add("--seed", default=None, type=int, help="random seed for training / 学習時の乱数のseed")
    add(
        "--gradient_checkpointing",
        action="store_true",
        help="enable gradient checkpointing / grandient checkpointingを有効にする",
    )
    add(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
    )
    add(
        "--mixed_precision",
        default="no",
        type=str,
        choices=["no", "fp16", "bf16"],
        help="use mixed precision / 混合精度を使う場合、その精度",
    )
    add("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
    add("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
    add(
        "--ddp_timeout",
        default=None,
        type=int,
        help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
    )
    add(
        "--ddp_gradient_as_bucket_view",
        action="store_true",
        help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
    )
    add(
        "--ddp_static_graph",
        action="store_true",
        help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
    )
    add(
        "--clip_skip",
        default=None,
        type=int,
        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)",
    )

    # --- logging ---
    add(
        "--logging_dir",
        default=None,
        type=str,
        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
    )
    add(
        "--log_with",
        default=None,
        type=str,
        choices=["tensorboard", "wandb", "all"],
        help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
    )
    add("--log_prefix", default=None, type=str, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
    add(
        "--log_tracker_name",
        default=None,
        type=str,
        help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
    )
    add(
        "--log_tracker_config",
        default=None,
        type=str,
        help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
    )
    add(
        "--wandb_api_key",
        default=None,
        type=str,
        help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
    )

    # --- noise options and timestep limits ---
    add(
        "--noise_offset",
        default=None,
        type=float,
        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
    )
    add(
        "--multires_noise_iterations",
        default=None,
        type=int,
        help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
    )
    add(
        "--ip_noise_gamma",
        default=None,
        type=float,
        help=(
            "enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
            "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)"
        ),
    )
    add(
        "--multires_noise_discount",
        default=0.3,
        type=float,
        help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)",
    )
    add(
        "--adaptive_noise_scale",
        default=None,
        type=float,
        help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
    )
    add(
        "--zero_terminal_snr",
        action="store_true",
        help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
    )
    add(
        "--min_timestep",
        default=None,
        type=int,
        help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
    )
    add(
        "--max_timestep",
        default=None,
        type=int,
        help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
    )

    add(
        "--lowram",
        action="store_true",
        help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)",
    )

    # --- sample image generation ---
    add(
        "--sample_every_n_steps",
        default=None,
        type=int,
        help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
    )
    add("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
    add(
        "--sample_every_n_epochs",
        default=None,
        type=int,
        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
    )
    add(
        "--sample_prompts",
        default=None,
        type=str,
        help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
    )
    sampler_choices = [
        "ddim",
        "pndm",
        "lms",
        "euler",
        "euler_a",
        "heun",
        "dpm_2",
        "dpm_2_a",
        "dpmsolver",
        "dpmsolver++",
        "dpmsingle",
        "k_lms",
        "k_euler",
        "k_euler_a",
        "k_dpm_2",
        "k_dpm_2_a",
    ]
    add(
        "--sample_sampler",
        default="ddim",
        type=str,
        choices=sampler_choices,
        help="sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
    )

    # --- configuration file handling ---
    add(
        "--config_file",
        default=None,
        type=str,
        help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
    )
    add(
        "--output_config",
        action="store_true",
        help="output command line args to given .toml file / 引数を.tomlファイルに出力する",
    )

    # --- model metadata ---
    add(
        "--metadata_title",
        default=None,
        type=str,
        help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
    )
    add(
        "--metadata_author",
        default=None,
        type=str,
        help="author name for model metadata / メタデータに書き込まれるモデル作者名",
    )
    add(
        "--metadata_description",
        default=None,
        type=str,
        help="description for model metadata / メタデータに書き込まれるモデル説明",
    )
    add(
        "--metadata_license",
        default=None,
        type=str,
        help="license for model metadata / メタデータに書き込まれるモデルライセンス",
    )
    add(
        "--metadata_tags",
        default=None,
        type=str,
        help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
    )

    if not support_dreambooth:
        return

    add(
        "--prior_loss_weight",
        default=1.0,
        type=float,
        help="loss weight for regularization images / 正則化画像のlossの重み",
    )
|
|
|
|
|
def verify_training_args(args: argparse.Namespace):
    """Check the parsed training options for questionable or contradictory combinations.

    Combinations that are merely unusual only print a warning. Combinations that
    cannot work raise ``ValueError``. As a side effect, ``args.cache_latents`` is
    turned on when ``args.cache_latents_to_disk`` is set without it.
    """
    # --- warnings only -------------------------------------------------------
    if args.v_parameterization:
        if not args.v2:
            print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
    if args.v2:
        if args.clip_skip is not None:
            print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

    # Caching to disk implies caching in general, so switch the flag on for the caller.
    disk_cache_without_cache = args.cache_latents_to_disk and not args.cache_latents
    if disk_cache_without_cache:
        args.cache_latents = True
        print(
            "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
        )

    # --- hard errors ---------------------------------------------------------
    if args.adaptive_noise_scale is not None:
        if args.noise_offset is None:
            raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

    if args.scale_v_pred_loss_like_noise_pred:
        if not args.v_parameterization:
            raise ValueError(
                "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
            )

    if args.v_pred_like_loss:
        if args.v_parameterization:
            raise ValueError(
                "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
            )

    # --- final warning -------------------------------------------------------
    if args.zero_terminal_snr and not args.v_parameterization:
        snr_warning = (
            "zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
            " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
        )
        print(snr_warning)
|
|
|
|
|
def add_dataset_arguments(
    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
    """Register the dataset-related command-line options on ``parser``.

    Args:
        parser: parser that receives the options.
        support_dreambooth: when true, ``--reg_data_dir`` is also registered.
        support_caption: when true, ``--in_json`` and ``--dataset_repeats`` are also registered.
        support_caption_dropout: when true, the three ``--caption_*dropout*`` options are also registered.
    """
    # --- image / caption sources ---
    parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
    parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
    parser.add_argument(
        "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
    )
    # Misspelled alias kept on purpose so that old command lines keep working.
    parser.add_argument(
        "--caption_extention",
        type=str,
        default=None,
        help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)",
    )
    parser.add_argument(
        "--keep_tokens",
        type=int,
        default=0,
        help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
    )
    parser.add_argument(
        "--keep_tokens_separator",
        type=str,
        default="",
        help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
        + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
    )
    parser.add_argument(
        "--caption_prefix",
        type=str,
        default=None,
        help="prefix for caption text / captionのテキストの先頭に付ける文字列",
    )
    parser.add_argument(
        "--caption_suffix",
        type=str,
        default=None,
        help="suffix for caption text / captionのテキストの末尾に付ける文字列",
    )

    # --- augmentation ---
    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
    parser.add_argument(
        "--face_crop_aug_range",
        type=str,
        default=None,
        help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)",
    )
    parser.add_argument(
        "--random_crop",
        action="store_true",
        help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)",
    )
    parser.add_argument(
        "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)"
    )

    # --- resolution, latent caching and bucketing ---
    parser.add_argument(
        "--resolution",
        type=str,
        default=None,
        help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)",
    )
    parser.add_argument(
        "--cache_latents",
        action="store_true",
        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
    )
    parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
    parser.add_argument(
        "--cache_latents_to_disk",
        action="store_true",
        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
    )
    parser.add_argument(
        "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
    )
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument(
        "--bucket_reso_steps",
        type=int,
        default=64,
        help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
    )
    parser.add_argument(
        "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
    )

    # --- token warmup ---
    # Help text fixed: the original read "comma separated strinfloatgs".
    parser.add_argument(
        "--token_warmup_min",
        type=int,
        default=1,
        help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
    )
    parser.add_argument(
        "--token_warmup_step",
        type=float,
        default=0,
        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
    )

    parser.add_argument(
        "--dataset_class",
        type=str,
        default=None,
        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
    )

    if support_caption_dropout:
        parser.add_argument(
            "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合"
        )
        parser.add_argument(
            "--caption_dropout_every_n_epochs",
            type=int,
            default=0,
            help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする",
        )
        parser.add_argument(
            "--caption_tag_dropout_rate",
            type=float,
            default=0.0,
            help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合",
        )

    if support_dreambooth:
        parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")

    if support_caption:
        parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
        parser.add_argument(
            "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数"
        )
|
|
|
|
|
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
    """Register the options that select the file format used when saving a Stable Diffusion model."""
    register = parser.add_argument

    # None means "use the same format as the model that was loaded".
    allowed_formats = [None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"]
    register(
        "--save_model_as",
        choices=allowed_formats,
        default=None,
        type=str,
        help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)",
    )

    register(
        "--use_safetensors",
        help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
        action="store_true",
    )
|
|
|
|
|
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Load options from, or dump options to, the ``.toml`` file named by ``args.config_file``.

    * Without ``args.config_file`` the namespace is returned untouched.
    * With ``args.output_config`` the non-default options are written to the file and the
      process exits (status 0; status 1 if the file already exists).
    * Otherwise the file is read, its sections are flattened into a single namespace and the
      command line is parsed on top of it, so explicit command-line options take precedence.

    Args:
        args: namespace produced by a first ``parser.parse_args()`` pass.
        parser: the parser that produced ``args``; used to obtain defaults and to re-parse.

    Returns:
        The (possibly re-parsed) namespace. ``config_file`` is stored without its extension.

    Raises:
        SystemExit: after writing the config, or when the file to write already exists, or
            when the file to read does not exist.
    """
    if not args.config_file:
        return args

    config_path = args.config_file if args.config_file.endswith(".toml") else args.config_file + ".toml"

    if args.output_config:
        # Never overwrite an existing config file.
        if os.path.exists(config_path):
            print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
            # SystemExit is raised directly: the ``exit`` builtin is injected by ``site`` and may be absent.
            raise SystemExit(1)

        # Options that describe this invocation or hold secrets are never written out.
        excluded_keys = {"config_file", "output_config", "wandb_api_key"}
        default_args = vars(parser.parse_args([]))

        args_dict = {}
        for key, value in vars(args).items():
            if key in excluded_keys:
                continue
            # Only values that differ from the parser defaults are worth persisting.
            if key in default_args and value == default_args[key]:
                continue
            # TOML has no path type; store paths as plain strings.
            args_dict[key] = str(value) if isinstance(value, pathlib.Path) else value

        # TOML documents are UTF-8; do not depend on the platform's locale encoding.
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(args_dict, f)

        print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
        raise SystemExit(0)

    if not os.path.exists(config_path):
        print(f"{config_path} not found.")
        raise SystemExit(1)

    print(f"Loading settings from {config_path}...")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = toml.load(f)

    # Flatten one level of nesting: the keys inside [section] tables become top-level options.
    ignore_nesting_dict = {}
    for section_name, section_dict in config_dict.items():
        if isinstance(section_dict, dict):
            ignore_nesting_dict.update(section_dict)
        else:
            # A top-level key/value pair is kept as is.
            ignore_nesting_dict[section_name] = section_dict

    # Values from the file seed the namespace; options given on the command line override them.
    config_args = argparse.Namespace(**ignore_nesting_dict)
    args = parser.parse_args(namespace=config_args)
    args.config_file = os.path.splitext(args.config_file)[0]
    print(args.config_file)

    return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resume_from_local_or_hf_if_specified(accelerator, args):
    """Restore a saved training state into ``accelerator`` when ``args.resume`` is set.

    With ``args.resume_from_huggingface`` unset, ``args.resume`` is passed directly to
    ``accelerator.load_state``. Otherwise ``args.resume`` is interpreted as
    ``{owner}/{repo}/{path_in_repo}[:{revision}[:{repo_type}]]``; every file listed under that
    path is downloaded from the Hugging Face Hub and the directory of the first downloaded
    file is loaded.

    Args:
        accelerator: object providing ``load_state(path)``.
        args: parsed options; reads ``resume``, ``resume_from_huggingface`` and
            ``huggingface_token``.

    Raises:
        ValueError: when no file is found at the given repo id / path / revision.
    """
    if not args.resume:
        return

    if not args.resume_from_huggingface:
        print(f"resume training from local state: {args.resume}")
        accelerator.load_state(args.resume)
        return

    # Imported here because only the Hub path needs it.
    from concurrent.futures import ThreadPoolExecutor

    print(f"resume training from huggingface state: {args.resume}")
    resume_parts = args.resume.split("/")
    repo_id = resume_parts[0] + "/" + resume_parts[1]
    path_in_repo = "/".join(resume_parts[2:])
    revision = None
    repo_type = None
    if ":" in path_in_repo:
        divided = path_in_repo.split(":")
        if len(divided) == 2:
            path_in_repo, revision = divided
            repo_type = "model"  # a revision without an explicit type refers to a model repo
        else:
            path_in_repo, revision, repo_type = divided
    print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")

    list_files = huggingface_util.list_dir(
        repo_id=repo_id,
        subfolder=path_in_repo,
        revision=revision,
        token=args.huggingface_token,
        repo_type=repo_type,
    )

    def download(file_info) -> str:
        return hf_hub_download(
            repo_id=repo_id,
            filename=file_info.rfilename,
            revision=revision,
            repo_type=repo_type,
            token=args.huggingface_token,
        )

    # The downloads are blocking I/O, so a thread pool runs them concurrently. This replaces
    # asyncio.get_event_loop()/gather() outside a running loop, which is deprecated and fails
    # on recent Python versions. ``map`` keeps the results in listing order.
    with ThreadPoolExecutor() as pool:
        results = list(pool.map(download, list_files))

    if len(results) == 0:
        raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
    dirname = os.path.dirname(results[0])
    accelerator.load_state(dirname)
|
|
|
|
|
def get_optimizer(args, trainable_params):
    """Create the optimizer selected by the parsed options.

    The optimizer is chosen case-insensitively from ``args.optimizer_type``, or from the legacy
    ``args.use_8bit_adam`` / ``args.use_lion_optimizer`` switches; ``AdamW`` is used when nothing
    is given. A name that is not handled explicitly is resolved with its original spelling:
    a plain name is looked up in ``torch.optim`` and a dotted name is imported as
    ``package.module.Class``.

    Args:
        args: parsed options. Reads ``optimizer_type``, ``use_8bit_adam``, ``use_lion_optimizer``,
            ``optimizer_args`` (list of ``key=python_literal`` strings) and ``learning_rate``.
            The Adafactor branch also reads ``lr_scheduler`` / ``max_grad_norm`` and, with
            ``relative_step`` enabled, overwrites ``learning_rate``, ``lr_scheduler`` and
            possibly ``unet_lr`` / ``text_encoder_lr`` on ``args``.
        trainable_params: parameters, or a list of parameter-group dicts, given to the optimizer.

    Returns:
        ``(optimizer_name, optimizer_args, optimizer)``: the fully qualified class name, the
        effective keyword arguments rendered as ``k=v,k=v`` and the optimizer instance.

    Raises:
        ImportError: a required optional package is not installed.
        AttributeError: the installed bitsandbytes lacks the requested optimizer, or the
            generic lookup cannot find the class.
        ValueError: an unknown ``DAdapt*`` optimizer was requested.
    """
    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit,
    #  PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdapt*, Prodigy, Adafactor"
    optimizer_type = args.optimizer_type
    if args.use_8bit_adam:
        assert (
            not args.use_lion_optimizer
        ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
        assert (
            optimizer_type is None or optimizer_type == ""
        ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
        optimizer_type = "AdamW8bit"

    elif args.use_lion_optimizer:
        assert (
            optimizer_type is None or optimizer_type == ""
        ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
        optimizer_type = "Lion"

    if optimizer_type is None or optimizer_type == "":
        optimizer_type = "AdamW"
    optimizer_type = optimizer_type.lower()

    # Parse the extra "key=value" arguments; values are Python literals.
    optimizer_kwargs = {}
    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
        for arg in args.optimizer_args:
            # Split on the first "=" only so that a literal value may itself contain "=".
            key, value = arg.split("=", 1)
            optimizer_kwargs[key] = ast.literal_eval(value)

    lr = args.learning_rate
    optimizer = None

    # 8-bit optimizers implemented explicitly below; any other "*8bit" name goes to the generic lookup.
    known_8bit_types = ("adamw8bit", "sgdnesterov8bit", "lion8bit", "pagedadamw8bit", "pagedlion8bit")

    if optimizer_type == "Lion".lower():
        try:
            import lion_pytorch
        except ImportError:
            raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
        print(f"use Lion optimizer | {optimizer_kwargs}")
        optimizer_class = lion_pytorch.Lion
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type in known_8bit_types:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

        # Arguments that the selected 8-bit optimizer always needs, in addition to the user's.
        forced_kwargs = {}

        if optimizer_type == "AdamW8bit".lower():
            print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
            optimizer_class = bnb.optim.AdamW8bit

        elif optimizer_type == "SGDNesterov8bit".lower():
            print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
            if "momentum" not in optimizer_kwargs:
                print(
                    f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
                )
                optimizer_kwargs["momentum"] = 0.9

            optimizer_class = bnb.optim.SGD8bit
            forced_kwargs["nesterov"] = True

        elif optimizer_type == "Lion8bit".lower():
            print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
            try:
                optimizer_class = bnb.optim.Lion8bit
            except AttributeError:
                raise AttributeError(
                    "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
                )
        elif optimizer_type == "PagedAdamW8bit".lower():
            print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
            try:
                optimizer_class = bnb.optim.PagedAdamW8bit
            except AttributeError:
                raise AttributeError(
                    "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
                )
        else:  # "PagedLion8bit"
            print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
            try:
                optimizer_class = bnb.optim.PagedLion8bit
            except AttributeError:
                raise AttributeError(
                    "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
                )

        # Build the optimizer exactly once. Building it again here without the forced
        # arguments would silently drop nesterov=True for SGDNesterov8bit.
        optimizer = optimizer_class(trainable_params, lr=lr, **forced_kwargs, **optimizer_kwargs)

    elif optimizer_type == "PagedAdamW".lower():
        print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        try:
            optimizer_class = bnb.optim.PagedAdamW
        except AttributeError:
            raise AttributeError(
                "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
            )
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "PagedAdamW32bit".lower():
        print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        try:
            optimizer_class = bnb.optim.PagedAdamW32bit
        except AttributeError:
            raise AttributeError(
                "No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
            )
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "SGDNesterov".lower():
        print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
        if "momentum" not in optimizer_kwargs:
            print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
            optimizer_kwargs["momentum"] = 0.9

        optimizer_class = torch.optim.SGD
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
        # Both families estimate the step size themselves, so warn about learning rates that
        # look like conventional values and about differing per-group learning rates.
        actual_lr = lr
        lr_count = 1
        if isinstance(trainable_params, list) and isinstance(trainable_params[0], dict):
            lrs = set()
            actual_lr = trainable_params[0].get("lr", actual_lr)
            for group in trainable_params:
                lrs.add(group.get("lr", actual_lr))
            lr_count = len(lrs)

        if actual_lr <= 0.1:
            print(
                f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
            )
            print("recommend option: lr=1.0 / 推奨は1.0です")
        if lr_count > 1:
            print(
                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
            )

        if optimizer_type.startswith("DAdapt".lower()):
            try:
                import dadaptation
                import dadaptation.experimental as experimental
            except ImportError:
                raise ImportError("No dadaptation / dadaptation がインストールされていないようです")

            # "DAdaptation" is kept as an alias of DAdaptAdamPreprint.
            if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
                optimizer_class = experimental.DAdaptAdamPreprint
                print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptAdaGrad".lower():
                optimizer_class = dadaptation.DAdaptAdaGrad
                print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptAdam".lower():
                optimizer_class = dadaptation.DAdaptAdam
                print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptAdan".lower():
                optimizer_class = dadaptation.DAdaptAdan
                print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptAdanIP".lower():
                optimizer_class = experimental.DAdaptAdanIP
                print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptLion".lower():
                optimizer_class = dadaptation.DAdaptLion
                print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
            elif optimizer_type == "DAdaptSGD".lower():
                optimizer_class = dadaptation.DAdaptSGD
                print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
            else:
                raise ValueError(f"Unknown optimizer type: {optimizer_type}")

            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
        else:
            try:
                import prodigyopt
            except ImportError:
                raise ImportError("No Prodigy / Prodigy がインストールされていないようです")

            print(f"use Prodigy optimizer | {optimizer_kwargs}")
            optimizer_class = prodigyopt.Prodigy
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "Adafactor".lower():
        # relative_step defaults to True, and warmup_init requires it.
        if "relative_step" not in optimizer_kwargs:
            optimizer_kwargs["relative_step"] = True
        if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
            print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
            optimizer_kwargs["relative_step"] = True
        print(f"use Adafactor optimizer | {optimizer_kwargs}")

        if optimizer_kwargs["relative_step"]:
            print(f"relative_step is true / relative_stepがtrueです")
            if lr != 0.0:
                print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
            args.learning_rate = None

            # With parameter groups, the per-group "lr" entries are removed as well.
            if isinstance(trainable_params, list) and isinstance(trainable_params[0], dict):
                has_group_lr = False
                for group in trainable_params:
                    p = group.pop("lr", None)
                    has_group_lr = has_group_lr or (p is not None)

                if has_group_lr:
                    # The matching options are cleared on args so later code does not use them.
                    print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
                    args.unet_lr = None
                    args.text_encoder_lr = None

            if args.lr_scheduler != "adafactor":
                print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
            # The requested learning rate is encoded in the scheduler name as "adafactor:<lr>".
            args.lr_scheduler = f"adafactor:{lr}"

            lr = None
        else:
            if args.max_grad_norm != 0.0:
                print(
                    f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
                )
            if args.lr_scheduler != "constant_with_warmup":
                print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
            if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
                print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")

        optimizer_class = transformers.optimization.Adafactor
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "AdamW".lower():
        print(f"use AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    if optimizer is None:
        # Arbitrary optimizer: use the name as originally given, because class names are case-sensitive.
        optimizer_type = args.optimizer_type
        print(f"use {optimizer_type} | {optimizer_kwargs}")
        if "." not in optimizer_type:
            optimizer_module = torch.optim
        else:
            values = optimizer_type.split(".")
            optimizer_module = importlib.import_module(".".join(values[:-1]))
            optimizer_type = values[-1]

        optimizer_class = getattr(optimizer_module, optimizer_type)
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
    optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

    return optimizer_name, optimizer_args, optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
    """
    Build the learning-rate scheduler described by ``args`` for ``optimizer``.

    Resolution order: a custom class named by ``args.lr_scheduler_type``, then the
    ``adafactor:<initial_lr>`` pseudo-name, then the schedulers registered in
    ``TYPE_TO_SCHEDULER_FUNCTION``. Extra ``key=literal`` pairs from ``args.lr_scheduler_args``
    are forwarded as keyword arguments. Schedulers that take no warmup raise ``ValueError``
    when a non-zero ``args.lr_warmup_steps`` is supplied.
    """
    name = args.lr_scheduler
    warmup_steps: Optional[int] = args.lr_warmup_steps
    total_steps = args.max_train_steps * num_processes
    cycles = args.lr_scheduler_num_cycles
    poly_power = args.lr_scheduler_power

    # Extra keyword arguments, given as "key=<python literal>".
    extra_kwargs = {}
    for spec in args.lr_scheduler_args or ():
        option, literal = spec.split("=")
        extra_kwargs[option] = ast.literal_eval(literal)

    def reject_warmup(scheduler):
        # Pass-through that refuses warmup steps for schedulers which cannot use them.
        if warmup_steps is None or warmup_steps == 0:
            return scheduler
        raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")

    # 1) arbitrary scheduler class
    if args.lr_scheduler_type:
        class_path = args.lr_scheduler_type
        print(f"use {class_path} | {extra_kwargs} as lr_scheduler")
        if "." in class_path:
            pieces = class_path.split(".")
            module = importlib.import_module(".".join(pieces[:-1]))
            class_path = pieces[-1]
        else:
            module = torch.optim.lr_scheduler
        scheduler_class = getattr(module, class_path)
        return reject_warmup(scheduler_class(optimizer, **extra_kwargs))

    # 2) Adafactor's own schedule; the initial lr is encoded after the colon
    if name.startswith("adafactor"):
        assert (
            type(optimizer) == transformers.optimization.Adafactor
        ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
        initial_lr = float(name.split(":")[1])
        return reject_warmup(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

    # 3) registered schedulers
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    if name == SchedulerType.CONSTANT:
        return reject_warmup(schedule_func(optimizer, **extra_kwargs))

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, **extra_kwargs)

    # Everything below needs the warmup length.
    if warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=warmup_steps, **extra_kwargs)

    # Everything below also needs the total number of steps.
    if total_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            num_cycles=cycles,
            **extra_kwargs,
        )
    elif name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            power=poly_power,
            **extra_kwargs,
        )
    else:
        return schedule_func(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            **extra_kwargs,
        )
|
|
|
|
|
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
    """Normalise the dataset options on ``args`` in place.

    * the misspelled ``caption_extention`` is moved into ``caption_extension``;
    * ``resolution`` (``"size"`` or ``"width,height"``) becomes a 2-tuple of ints;
    * ``face_crop_aug_range`` (``"lower,upper"``) becomes a 2-tuple of floats;
    * with ``support_metadata``, a notice is printed when ``in_json`` is combined with
      ``color_aug`` or ``random_crop``.
    """
    # Legacy option name: copy its value over, then clear it.
    legacy_extension = args.caption_extention
    if legacy_extension is not None:
        args.caption_extension = legacy_extension
        args.caption_extention = None

    if args.resolution is not None:
        dims = tuple(map(int, args.resolution.split(",")))
        if len(dims) == 1:
            # A single number means a square image.
            dims = (dims[0], dims[0])
        args.resolution = dims
        assert (
            len(args.resolution) == 2
        ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"

    crop_spec = args.face_crop_aug_range
    if crop_spec is None:
        args.face_crop_aug_range = None
    else:
        args.face_crop_aug_range = tuple(float(part) for part in crop_spec.split(","))
        assert (
            len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1]
        ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"

    if support_metadata and args.in_json is not None and (args.color_aug or args.random_crop):
        print(
            f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます"
        )
|
|
|
|
|
def load_tokenizer(args: argparse.Namespace):
    """Return the CLIP tokenizer for the selected model family.

    When ``args.tokenizer_cache_dir`` is set, the tokenizer is read from that directory if a
    cached copy exists and written there after loading if it does not.
    """
    print("prepare tokenizer")
    source_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH

    tokenizer: CLIPTokenizer = None

    # Try the local cache first.
    if args.tokenizer_cache_dir:
        cache_name = source_path.replace("/", "_")
        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, cache_name)
        if os.path.exists(local_tokenizer_path):
            print(f"load tokenizer from cache: {local_tokenizer_path}")
            # A cached copy is loaded the same way for v1 and v2.
            tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

    # Cache miss or no cache configured: load from the original location.
    if tokenizer is None:
        # Only the v2 load passes a "tokenizer" subfolder.
        load_kwargs = {"subfolder": "tokenizer"} if args.v2 else {}
        tokenizer = CLIPTokenizer.from_pretrained(source_path, **load_kwargs)

    if getattr(args, "max_token_length", None) is not None:
        print(f"update token length: {args.max_token_length}")

    # Populate the cache for the next run.
    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
        print(f"save Tokenizer to cache: {local_tokenizer_path}")
        tokenizer.save_pretrained(local_tokenizer_path)

    return tokenizer
|
|
|
|
|
def prepare_accelerator(args: argparse.Namespace):
    """Create the ``Accelerator`` configured from the logging, precision and DDP options.

    Raises:
        ValueError: tensorboard logging was requested without ``args.logging_dir``.
        ImportError: wandb logging was requested but ``wandb`` is not installed.
    """
    # Each run logs into its own time-stamped directory below args.logging_dir.
    logging_dir = None
    if args.logging_dir is not None:
        log_prefix = args.log_prefix if args.log_prefix is not None else ""
        logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

    log_with = args.log_with
    if log_with is None:
        # No tracker requested: fall back to tensorboard when a log directory was given.
        log_with = "tensorboard" if logging_dir is not None else None
    else:
        if log_with in ("tensorboard", "all") and logging_dir is None:
            raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
        if log_with in ("wandb", "all"):
            try:
                import wandb
            except ImportError:
                raise ImportError("No wandb / wandb がインストールされていないようです")
            if logging_dir is not None:
                os.makedirs(logging_dir, exist_ok=True)
                os.environ["WANDB_DIR"] = logging_dir
            if args.wandb_api_key is not None:
                wandb.login(key=args.wandb_api_key)

    # Optional handlers are only added when the matching option is set.
    kwargs_handlers = []
    if args.ddp_timeout:
        kwargs_handlers.append(InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)))
    if args.ddp_gradient_as_bucket_view or args.ddp_static_graph:
        kwargs_handlers.append(
            DistributedDataParallelKwargs(
                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
            )
        )

    return Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=log_with,
        project_dir=logging_dir,
        kwargs_handlers=kwargs_handlers,
    )
|
|
|
|
|
def prepare_dtype(args: argparse.Namespace):
    """Translate the precision options into torch dtypes.

    Returns:
        ``(weight_dtype, save_dtype)``. ``weight_dtype`` follows ``args.mixed_precision`` and is
        ``torch.float32`` unless that is ``"fp16"`` or ``"bf16"``. ``save_dtype`` follows
        ``args.save_precision`` (``"fp16"``, ``"bf16"`` or ``"float"``) and is ``None`` otherwise.
    """

    def first_match(requested, candidates, fallback):
        # Scan in order and compare with ==, like an if/elif ladder.
        for label, dtype in candidates:
            if requested == label:
                return dtype
        return fallback

    weight_dtype = first_match(
        args.mixed_precision,
        (("fp16", torch.float16), ("bf16", torch.bfloat16)),
        torch.float32,
    )
    save_dtype = first_match(
        args.save_precision,
        (("fp16", torch.float16), ("bf16", torch.bfloat16), ("float", torch.float32)),
        None,
    )
    return weight_dtype, save_dtype
|
|
|
|
|
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
    """Load the text encoder, VAE and U-Net named by ``args.pretrained_model_name_or_path``.

    A path to an existing file is read as a Stable Diffusion checkpoint. Anything else is
    handed to ``StableDiffusionPipeline.from_pretrained`` and the resulting U-Net is rebuilt
    as this project's ``UNet2DConditionModel``. When ``args.vae`` is set, that VAE replaces
    the one that came with the model.

    Returns:
        ``(text_encoder, vae, unet, load_stable_diffusion_format)``; the last item is true when
        the model was read from a single checkpoint file.
    """
    model_ref = args.pretrained_model_name_or_path
    if os.path.islink(model_ref):
        # Resolve symlinks so the file test below looks at the real target.
        model_ref = os.path.realpath(model_ref)

    # A regular file is a Stable Diffusion checkpoint; otherwise a Diffusers model is assumed.
    is_checkpoint_file = os.path.isfile(model_ref)

    if is_checkpoint_file:
        print(f"load StableDiffusion checkpoint: {model_ref}")
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
            args.v2, model_ref, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
        )
    else:
        print(f"load Diffusers pretrained models: {model_ref}")
        try:
            pipe = StableDiffusionPipeline.from_pretrained(model_ref, tokenizer=None, safety_checker=None)
        except EnvironmentError:
            print(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {model_ref}"
            )
            raise
        text_encoder, vae, diffusers_unet = pipe.text_encoder, pipe.vae, pipe.unet
        del pipe

        # Rebuild the Diffusers U-Net as the project's own U-Net class and copy the weights over.
        unet_config = diffusers_unet.config
        unet = UNet2DConditionModel(
            unet_config.sample_size,
            unet_config.attention_head_dim,
            unet_config.cross_attention_dim,
            unet_config.use_linear_projection,
            unet_config.upcast_attention,
        )
        unet.load_state_dict(diffusers_unet.state_dict())
        print("U-Net converted to original U-Net")

    # An explicitly given VAE replaces the bundled one.
    if args.vae is not None:
        vae = model_util.load_vae(args.vae, weight_dtype)
        print("additional VAE loaded")

    return text_encoder, vae, unet, is_checkpoint_file
|
|
|
|
|
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
    """Load the models one process at a time and return them.

    The loop walks over every process index; only the iteration matching this
    process's ``local_process_index`` loads, and every iteration ends with
    ``accelerator.wait_for_everyone()``.

    Returns:
        ``(text_encoder, vae, unet, load_stable_diffusion_format)`` as produced by
        ``_load_target_model``.
    """
    for turn in range(accelerator.state.num_processes):
        if turn == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            loaded = _load_target_model(
                args,
                weight_dtype,
                accelerator.device if args.lowram else "cpu",
                unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
            )
            text_encoder, vae, unet, load_stable_diffusion_format = loaded

            # With --lowram the models are moved to the accelerator device immediately.
            if args.lowram:
                text_encoder.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()

        accelerator.wait_for_everyone()

    return text_encoder, vae, unet, load_stable_diffusion_format
|
|
|
|
|
def patch_accelerator_for_fp16_training(accelerator):
    """Replace ``accelerator.scaler._unscale_grads_`` with a wrapper that forces ``allow_fp16``.

    The wrapper forwards its first three arguments unchanged and always passes
    ``True`` as the fourth argument to the original method.
    """
    scaler = accelerator.scaler
    original_unscale = scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        # The caller's allow_fp16 value is deliberately ignored.
        return original_unscale(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = _unscale_grads_replacer
|
|
|
|
|
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
    """Run ``text_encoder`` on ``input_ids`` and return the resulting hidden states.

    Reads ``args.clip_skip``, ``args.max_token_length`` and ``args.v2``. When
    ``weight_dtype`` is given the result is cast to it before returning.
    """
    # If the last dimension is not the tokenizer's max length, encode directly
    # and return the first encoder output without any further processing.
    if input_ids.size()[-1] != tokenizer.model_max_length:
        return text_encoder(input_ids)[0]

    # Remember the leading (batch) size, then flatten into rows of model_max_length tokens.
    b_size = input_ids.size()[0]
    input_ids = input_ids.reshape((-1, tokenizer.model_max_length))

    if args.clip_skip is None:
        encoder_hidden_states = text_encoder(input_ids)[0]
    else:
        # Take the hidden state clip_skip layers from the end and apply the final layer norm to it.
        enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
        encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
        encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

    # Restore the batch dimension; all rows of one sample are laid end to end on axis 1.
    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

    if args.max_token_length is not None:
        if args.v2:
            # Stitch the rows back together: first position, then the interior slice
            # of each row (model_max_length - 2 positions), then the last position.
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]
                # NOTE(review): i starts at 1, so this condition is always true.
                if i > 0:
                    for j in range(len(chunk)):
                        # NOTE(review): this compares a token id with tokenizer.eos_token, whereas the
                        # SDXL path in this file compares ids with eos_token_id. If eos_token is the
                        # string form, this test may never be true and the copy below never runs.
                        # input_ids was also flattened above, so row j is presumably not the row that
                        # belongs to this chunk when a sample spans several rows. Confirm the intent.
                        if input_ids[j, 1] == tokenizer.eos_token:
                            # Overwrite the first position of this sample's chunk with the second one
                            # (chunk is a slice of encoder_hidden_states).
                            chunk[j, 0] = chunk[j, 1]
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
            encoder_hidden_states = torch.cat(states_list, dim=1)
        else:
            # Same stitching as above, without the per-sample adjustment.
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2])
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
            encoder_hidden_states = torch.cat(states_list, dim=1)

    if weight_dtype is not None:
        # Cast only when the caller asked for a specific dtype.
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

    return encoder_hidden_states
|
|
|
|
|
def pool_workaround(
    text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
    r"""
    Compute the pooled text embedding from the hidden state at the EOS token.

    CLIP's own pooling picks the position of the largest token id in each
    sequence (``input_ids.argmax(dim=-1)``). This function instead looks up the
    position where ``input_ids == eos_token_id`` and pools from there, then
    applies ``text_encoder.text_projection``.

    The position is the argmax of a 0/1 mask, so a row that contains no EOS
    token yields position 0.

    Returns the projected vectors cast back to ``last_hidden_state``'s dtype.
    """
    target_device = last_hidden_state.device

    # 1 where the token is EOS, 0 elsewhere; argmax then selects an EOS position per row.
    is_eos = (input_ids == eos_token_id).int()
    eos_position = torch.argmax(is_eos, dim=1)
    eos_position = eos_position.to(device=target_device)

    row_index = torch.arange(last_hidden_state.shape[0], device=target_device)
    eos_hidden = last_hidden_state[row_index, eos_position]

    # Project in the projection layer's dtype, then return in the caller's dtype.
    projection = text_encoder.text_projection
    projected = projection(eos_hidden.to(projection.weight.dtype))
    return projected.to(last_hidden_state.dtype)
|
|
|
|
|
def get_hidden_states_sdxl(
    max_token_length: int,
    input_ids1: torch.Tensor,
    input_ids2: torch.Tensor,
    tokenizer1: CLIPTokenizer,
    tokenizer2: CLIPTokenizer,
    text_encoder1: CLIPTextModel,
    text_encoder2: CLIPTextModelWithProjection,
    weight_dtype: Optional[str] = None,
    accelerator: Optional[Accelerator] = None,
):
    """Encode the two SDXL token streams and return ``(hidden_states1, hidden_states2, pool2)``.

    ``hidden_states1`` is hidden state index 11 of the first encoder,
    ``hidden_states2`` is the second-to-last hidden state of the second encoder,
    and ``pool2`` is the EOS-pooled, projected output of the second encoder
    (see ``pool_workaround``). When ``max_token_length`` is set, the per-row
    results are stitched back into a single sequence per sample and only every
    ``max_token_length // 75``-th pooled vector is kept.
    """

    def _stitch(states, row_len):
        # First position, the interior of every row (row_len - 2 positions each), last position.
        pieces = [states[:, 0].unsqueeze(1)]
        for start in range(1, max_token_length, row_len):
            pieces.append(states[:, start : start + row_len - 2])
        pieces.append(states[:, -1].unsqueeze(1))
        return torch.cat(pieces, dim=1)

    batch = input_ids1.size()[0]
    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))

    # Text encoder 1
    out1 = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = out1["hidden_states"][11]

    # Text encoder 2
    out2 = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = out2["hidden_states"][-2]

    # The projection layer is read from the unwrapped model when an accelerator is supplied.
    if accelerator is None:
        projecting_encoder = text_encoder2
    else:
        projecting_encoder = accelerator.unwrap_model(text_encoder2)
    pool2 = pool_workaround(projecting_encoder, out2["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

    n_size = 1 if max_token_length is None else max_token_length // 75
    hidden_states1 = hidden_states1.reshape((batch, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((batch, -1, hidden_states2.shape[-1]))

    if max_token_length is not None:
        hidden_states1 = _stitch(hidden_states1, tokenizer1.model_max_length)
        hidden_states2 = _stitch(hidden_states2, tokenizer2.model_max_length)

        # Keep the pooled vector of the first row of each sample.
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
|
|
|
|
|
def default_if_none(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``."""
    if value is None:
        return default
    return value
|
|
|
|
|
def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
    """Build the checkpoint file name for ``epoch_no`` with extension ``ext``.

    Uses ``args.output_name`` when set, otherwise ``DEFAULT_EPOCH_NAME``.
    """
    base_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
    stem = EPOCH_FILE_NAME.format(base_name, epoch_no)
    return stem + ext
|
|
|
|
|
def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
    """Build the checkpoint file name for ``step_no`` with extension ``ext``.

    Uses ``args.output_name`` when set, otherwise ``DEFAULT_STEP_NAME``.
    """
    base_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
    stem = STEP_FILE_NAME.format(base_name, step_no)
    return stem + ext
|
|
|
|
|
def get_last_ckpt_name(args: argparse.Namespace, ext: str):
    """Build the final checkpoint file name: the output name (or the default) plus ``ext``."""
    base_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
    return f"{base_name}{ext}" if False else base_name + ext
|
|
|
|
|
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
    """Return the epoch number whose checkpoint should be deleted, or ``None``.

    ``None`` is returned when ``args.save_last_n_epochs`` is unset or when the
    computed epoch number is negative.
    """
    keep_count = args.save_last_n_epochs
    if keep_count is None:
        return None

    candidate = epoch_no - args.save_every_n_epochs * keep_count
    return candidate if candidate >= 0 else None
|
|
|
|
|
def get_remove_step_no(args: argparse.Namespace, step_no: int):
    """Return the step number whose checkpoint should be deleted, or ``None``.

    The candidate is ``step_no - save_last_n_steps - 1`` rounded down to a
    multiple of ``args.save_every_n_steps``. ``None`` is returned when
    ``args.save_last_n_steps`` is unset or the rounded value is negative.
    """
    if args.save_last_n_steps is None:
        return None

    candidate = step_no - args.save_last_n_steps - 1
    # Drop the remainder so the result lands on a step at which a save happened.
    aligned = candidate - (candidate % args.save_every_n_steps)
    return None if aligned < 0 else aligned
|
|
|
|
|
|
|
|
|
def save_sd_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    src_path: str,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    text_encoder,
    unet,
    vae,
):
    """Save an intermediate model (at epoch end or at a step) for SD v1/v2 training.

    Builds the two format-specific writer callbacks and delegates the
    scheduling, naming and clean-up to
    ``save_sd_model_on_epoch_end_or_stepwise_common``.
    """

    def _write_checkpoint(ckpt_file, epoch_no, global_step):
        # Single-file Stable Diffusion checkpoint, with model-spec metadata attached.
        metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
        model_util.save_stable_diffusion_checkpoint(
            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, metadata, save_dtype, vae
        )

    def _write_diffusers(out_dir):
        # Diffusers directory layout.
        model_util.save_diffusers_checkpoint(
            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
        )

    save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        save_stable_diffusion_format,
        use_safetensors,
        epoch,
        num_train_epochs,
        global_step,
        _write_checkpoint,
        _write_diffusers,
    )
|
|
|
|
|
def save_sd_model_on_epoch_end_or_stepwise_common(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    sd_saver,
    diffusers_saver,
):
    """Save an intermediate model, optionally upload it, prune an old one and save state.

    At epoch end nothing is saved unless ``epoch + 1`` is a multiple of
    ``args.save_every_n_epochs`` and smaller than ``num_train_epochs``. For
    step-wise saving no such check is made here.

    ``sd_saver(ckpt_file, epoch_no, global_step)`` writes a single-file
    checkpoint; ``diffusers_saver(out_dir)`` writes a Diffusers directory.
    """
    if on_epoch_end:
        epoch_no = epoch + 1
        if epoch_no % args.save_every_n_epochs != 0 or epoch_no >= num_train_epochs:
            return

        model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
        remove_no = get_remove_epoch_no(args, epoch_no)
    else:
        model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
        epoch_no = epoch
        remove_no = get_remove_step_no(args, global_step)

    os.makedirs(args.output_dir, exist_ok=True)

    if not save_stable_diffusion_format:
        # Diffusers directory output.
        dir_template = EPOCH_DIFFUSERS_DIR_NAME if on_epoch_end else STEP_DIFFUSERS_DIR_NAME
        current_no = epoch_no if on_epoch_end else global_step
        out_dir = os.path.join(args.output_dir, dir_template.format(model_name, current_no))

        print(f"\nsaving model: {out_dir}")
        diffusers_saver(out_dir)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name)

        # Remove the directory that fell out of the retention window, if it exists.
        if remove_no is not None:
            remove_out_dir = os.path.join(args.output_dir, dir_template.format(model_name, remove_no))
            if os.path.exists(remove_out_dir):
                print(f"removing old model: {remove_out_dir}")
                shutil.rmtree(remove_out_dir)
    else:
        # Single-file checkpoint output.
        ext = ".safetensors" if use_safetensors else ".ckpt"
        name_for = get_epoch_ckpt_name if on_epoch_end else get_step_ckpt_name
        ckpt_name = name_for(args, ext, epoch_no if on_epoch_end else global_step)

        ckpt_file = os.path.join(args.output_dir, ckpt_name)
        print(f"\nsaving checkpoint: {ckpt_file}")
        sd_saver(ckpt_file, epoch_no, global_step)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

        # Remove the file that fell out of the retention window, if it exists.
        if remove_no is not None:
            remove_ckpt_file = os.path.join(args.output_dir, name_for(args, ext, remove_no))
            if os.path.exists(remove_ckpt_file):
                print(f"removing old checkpoint: {remove_ckpt_file}")
                os.remove(remove_ckpt_file)

    if not args.save_state:
        return
    if on_epoch_end:
        save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
    else:
        save_and_remove_state_stepwise(args, accelerator, global_step)
|
|
|
|
|
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
    """Save the accelerator state for ``epoch_no`` and delete the state that aged out.

    The state is uploaded when ``args.save_state_to_huggingface`` is set. The
    retention count is ``args.save_last_n_epochs_state`` when truthy, otherwise
    ``args.save_last_n_epochs``; with neither set nothing is deleted.
    """
    model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)

    print(f"\nsaving state at epoch {epoch_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = EPOCH_STATE_NAME.format(model_name, epoch_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + state_name)

    keep_count = args.save_last_n_epochs_state or args.save_last_n_epochs
    if keep_count is None:
        return

    stale_epoch_no = epoch_no - args.save_every_n_epochs * keep_count
    state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, stale_epoch_no))
    if os.path.exists(state_dir_old):
        print(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
|
|
|
|
|
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
    """Save the accelerator state for ``step_no`` and delete the state that aged out.

    The state is uploaded when ``args.save_state_to_huggingface`` is set. The
    retention window is ``args.save_last_n_steps_state`` when truthy, otherwise
    ``args.save_last_n_steps``; with neither set nothing is deleted. An old
    state is only looked for when the computed step number is positive.
    """
    model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)

    print(f"\nsaving state at step {step_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = STEP_STATE_NAME.format(model_name, step_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + state_name)

    keep_steps = args.save_last_n_steps_state or args.save_last_n_steps
    if keep_steps is None:
        return

    # Round down to a multiple of save_every_n_steps, i.e. a step at which a state was saved.
    stale_step_no = step_no - keep_steps - 1
    stale_step_no -= stale_step_no % args.save_every_n_steps
    if stale_step_no <= 0:
        return

    state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, stale_step_no))
    if os.path.exists(state_dir_old):
        print(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
|
|
|
|
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    """Save the final accelerator state under ``args.output_dir`` and optionally upload it."""
    print("\nsaving last state.")
    os.makedirs(args.output_dir, exist_ok=True)

    base_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
    state_name = LAST_STATE_NAME.format(base_name)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)

    if not args.save_state_to_huggingface:
        return
    print("uploading last state to huggingface.")
    huggingface_util.upload(args, state_dir, "/" + state_name)
|
|
|
|
|
def save_sd_model_on_train_end(
    args: argparse.Namespace,
    src_path: str,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    text_encoder,
    unet,
    vae,
):
    """Save the final SD v1/v2 model.

    Builds the two format-specific writer callbacks and delegates naming and
    upload to ``save_sd_model_on_train_end_common``.
    """

    def _write_checkpoint(ckpt_file, epoch_no, global_step):
        # Single-file Stable Diffusion checkpoint, with model-spec metadata attached.
        metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
        model_util.save_stable_diffusion_checkpoint(
            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, metadata, save_dtype, vae
        )

    def _write_diffusers(out_dir):
        # Diffusers directory layout.
        model_util.save_diffusers_checkpoint(
            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
        )

    save_sd_model_on_train_end_common(
        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, _write_checkpoint, _write_diffusers
    )
|
|
|
|
|
def save_sd_model_on_train_end_common(
    args: argparse.Namespace,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    epoch: int,
    global_step: int,
    sd_saver,
    diffusers_saver,
):
    """Write the final model in the requested format and upload it when a repo id is configured.

    ``sd_saver(ckpt_file, epoch, global_step)`` writes a single-file checkpoint;
    ``diffusers_saver(out_dir)`` writes a Diffusers directory. Uploads at this
    point are requested with ``force_sync_upload=True``.
    """
    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)

    if not save_stable_diffusion_format:
        out_dir = os.path.join(args.output_dir, model_name)
        os.makedirs(out_dir, exist_ok=True)

        print(f"save trained model as Diffusers to {out_dir}")
        diffusers_saver(out_dir)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
        return

    os.makedirs(args.output_dir, exist_ok=True)

    extension = ".safetensors" if use_safetensors else ".ckpt"
    ckpt_name = model_name + extension
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
    sd_saver(ckpt_file, epoch, global_step)

    if args.huggingface_repo_id is not None:
        huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
|
|
|
|
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    """Sample noise and timesteps for ``latents`` and return ``(noise, noisy_latents, timesteps)``.

    Optional modifiers read from ``args``: ``noise_offset`` (with
    ``adaptive_noise_scale``), ``multires_noise_iterations`` (with
    ``multires_noise_discount``), ``min_timestep`` / ``max_timestep`` and
    ``ip_noise_gamma``. The returned ``noise`` never includes the
    ``ip_noise_gamma`` perturbation; that is only added to what is passed to
    ``noise_scheduler.add_noise``.
    """
    device = latents.device

    # Base noise, then the optional noise modifiers in their fixed order.
    noise = torch.randn_like(latents, device=device)
    if args.noise_offset:
        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
    if args.multires_noise_iterations:
        noise = custom_train_functions.pyramid_noise_like(
            noise, device, args.multires_noise_iterations, args.multires_noise_discount
        )

    # One random timestep per sample, drawn from [low, high).
    low = args.min_timestep if args.min_timestep is not None else 0
    if args.max_timestep is not None:
        high = args.max_timestep
    else:
        high = noise_scheduler.config.num_train_timesteps
    timesteps = torch.randint(low, high, (latents.shape[0],), device=device).long()

    # Forward diffusion; with ip_noise_gamma an extra random perturbation is mixed into the noise used here.
    if args.ip_noise_gamma:
        applied_noise = noise + args.ip_noise_gamma * torch.randn_like(latents)
    else:
        applied_noise = noise
    noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)

    return noise, noisy_latents, timesteps
|
|
|
|
|
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
    """Add learning-rate entries to ``logs`` for the U-Net (optional) and both text encoders.

    The names are assigned to the scheduler's learning rates by position.
    """
    names = ["unet"] if including_unet else []
    names += ["text_encoder1", "text_encoder2"]

    append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
|
|
|
|
def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
    """Write ``lr/<name>`` into ``logs`` for each learning rate reported by ``lr_scheduler``.

    ``names[i]`` labels the i-th value of ``lr_scheduler.get_last_lr()``. For
    optimizer types starting with "DAdapt" or equal to "Prodigy" (case
    insensitive), ``lr/d*lr/<name>`` is also written as the product of the
    ``"d"`` and ``"lr"`` entries of the matching param group of the scheduler's
    last optimizer.
    """
    for index, last_lr in enumerate(lr_scheduler.get_last_lr()):
        name = names[index]
        logs["lr/" + name] = float(last_lr)

        optimizer_kind = optimizer_type.lower()
        if optimizer_kind.startswith("DAdapt".lower()) or optimizer_kind == "Prodigy".lower():
            param_group = lr_scheduler.optimizers[-1].param_groups[index]
            logs["lr/d*lr/" + name] = param_group["d"] * param_group["lr"]
|
|
|
|
|
|
|
# Noise-schedule settings passed to the scheduler constructed in get_my_scheduler().
SCHEDULER_LINEAR_START = 0.00085  # passed as beta_start
SCHEDULER_LINEAR_END = 0.0120  # passed as beta_end
SCHEDULER_TIMESTEPS = 1000  # passed as num_train_timesteps
# Passed as beta_schedule. The name is spelled "SCHEDLER" (sic); it is kept because
# get_my_scheduler() refers to it by this name.
SCHEDLER_SCHEDULE = "scaled_linear"
|
|
|
|
|
def get_my_scheduler(
    *,
    sample_sampler: str,
    v_parameterization: bool,
):
    """Build the sampling scheduler selected by ``sample_sampler``.

    Unknown sampler names fall back to ``DDIMScheduler``. For "dpmsolver" and
    "dpmsolver++" the name is also passed on as ``algorithm_type``. With
    ``v_parameterization`` the scheduler is created with
    ``prediction_type="v_prediction"``.
    """
    sampler_classes = {
        "ddim": DDIMScheduler,
        "ddpm": DDPMScheduler,
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler,
        "k_lms": LMSDiscreteScheduler,
        "euler": EulerDiscreteScheduler,
        "k_euler": EulerDiscreteScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "k_euler_a": EulerAncestralDiscreteScheduler,
        "dpmsolver": DPMSolverMultistepScheduler,
        "dpmsolver++": DPMSolverMultistepScheduler,
        "dpmsingle": DPMSolverSinglestepScheduler,
        "heun": HeunDiscreteScheduler,
        "dpm_2": KDPM2DiscreteScheduler,
        "k_dpm_2": KDPM2DiscreteScheduler,
        "dpm_2_a": KDPM2AncestralDiscreteScheduler,
        "k_dpm_2_a": KDPM2AncestralDiscreteScheduler,
    }
    scheduler_cls = sampler_classes.get(sample_sampler, DDIMScheduler)

    extra_init_args = {}
    if sample_sampler in ("dpmsolver", "dpmsolver++"):
        extra_init_args["algorithm_type"] = sample_sampler
    if v_parameterization:
        extra_init_args["prediction_type"] = "v_prediction"

    scheduler = scheduler_cls(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
        **extra_init_args,
    )

    # Force clip_sample on for schedulers whose config has it and has it explicitly set to False.
    config = scheduler.config
    if hasattr(config, "clip_sample") and config.clip_sample is False:
        config.clip_sample = True

    return scheduler
|
|
|
|
|
def sample_images(*args, **kwargs):
    """Generate sample images with ``StableDiffusionLongPromptWeightingPipeline``.

    All arguments are forwarded unchanged to ``sample_images_common`` after the
    pipeline class.
    """
    pipeline_class = StableDiffusionLongPromptWeightingPipeline
    return sample_images_common(pipeline_class, *args, **kwargs)
|
|
|
|
|
# (regex, result key, converter) for each ``--`` option; tried in order, first match wins.
_PROMPT_OPTION_PARSERS = (
    (r"w (\d+)", "width", int),
    (r"h (\d+)", "height", int),
    (r"d (\d+)", "seed", int),
    (r"s (\d+)", "sample_steps", lambda value: max(1, min(1000, int(value)))),
    (r"l ([\d\.]+)", "scale", float),
    (r"n (.+)", "negative_prompt", str),
    (r"ss (.+)", "sample_sampler", str),
    (r"cn (.+)", "controlnet_image", str),
)


def line_to_prompt_dict(line: str) -> dict:
    """Parse one prompt line of the form ``<prompt> --w 512 --h 512 --n <negative> ...``.

    The text before the first `` --`` is stored under ``"prompt"``. Each later
    segment is matched (case-insensitively) against the known options:
    ``w`` width, ``h`` height, ``d`` seed, ``s`` sample steps (clamped to
    1..1000), ``l`` guidance scale, ``n`` negative prompt, ``ss`` sampler name
    and ``cn`` ControlNet image path. Unrecognised segments are ignored; a value
    that cannot be converted is reported on stdout and skipped.

    Only the option segments are scanned. The prompt text itself is never
    interpreted as an option, so a prompt such as ``"d 5 cats"`` does not set a
    seed.
    """
    prompt_text, *option_args = line.split(" --")
    prompt_dict = {"prompt": prompt_text}

    for parg in option_args:
        try:
            for pattern, key, convert in _PROMPT_OPTION_PARSERS:
                m = re.match(pattern, parg, re.IGNORECASE)
                if m:
                    prompt_dict[key] = convert(m.group(1))
                    break
        except ValueError as ex:
            # e.g. "l 1.2.3": the pattern matches but float() rejects the value.
            print(f"Exception in parsing / 解析エラー: {parg}")
            print(ex)

    return prompt_dict
|
|
|
|
|
def _load_sample_prompts(prompt_file):
    """Read sample prompts from a ``.txt``, ``.toml`` or ``.json`` file.

    Returns a list of prompt lines (``.txt``) or prompt dicts / whatever the
    file contains (``.toml`` / ``.json``), or ``None`` when the extension is
    not one of the three supported ones.
    """
    if prompt_file.endswith(".txt"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # Skip blank lines and lines whose first character is "#".
        return [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]

    if prompt_file.endswith(".toml"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            data = toml.load(f)
        # Merge the shared [prompt] settings into every subset; a key given in the
        # subset overrides the shared one (dict(**a, **b) would raise TypeError on a clash).
        return [{**data["prompt"], **subset} for subset in data["prompt"]["subset"]]

    if prompt_file.endswith(".json"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            return json.load(f)

    return None


def sample_images_common(
    pipe_class,
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    device,
    vae,
    tokenizer,
    text_encoder,
    unet,
    prompt_replacement=None,
    controlnet=None,
):
    """
    Generate sample images from the prompts in ``args.sample_prompts`` and save them
    under ``<args.output_dir>/sample``.

    The pipeline class passed in is constructed with ``clip_skip=args.clip_skip``
    (the callers in this file pass a modified StableDiffusionLongPromptWeightingPipeline).

    Sampling is skipped (the function returns without doing anything) when:
      * ``steps == 0`` and ``args.sample_at_first`` is not set;
      * neither ``args.sample_every_n_steps`` nor ``args.sample_every_n_epochs`` is set;
      * the epoch / step does not fall on the configured interval;
      * the prompt file does not exist or has an unsupported extension.

    Only the main process generates images. The torch RNG state (CPU, and CUDA when
    available) and the VAE's device are restored before returning.
    """
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # When an epoch interval is given, the step interval is not consulted.
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            # Step-wise sampling: skip off-interval steps and end-of-epoch calls.
            if steps % args.sample_every_n_steps != 0 or epoch is not None:
                return

    print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts):
        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    # Load the prompts before touching any model so that an unusable prompt file
    # cannot leave the VAE on the wrong device.
    prompts = _load_sample_prompts(args.sample_prompts)
    if prompts is None:
        print(f"Unsupported prompt file type (expected .txt, .toml or .json): {args.sample_prompts}")
        return

    org_vae_device = vae.device
    vae.to(device)

    # Unwrap the U-Net and the text encoder(s) from the accelerator wrappers.
    unet = accelerator.unwrap_model(unet)
    if isinstance(text_encoder, (list, tuple)):
        text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
    else:
        text_encoder = accelerator.unwrap_model(text_encoder)

    # Schedulers are created on demand and cached per sampler name.
    schedulers: dict = {}
    default_scheduler = get_my_scheduler(
        sample_sampler=args.sample_sampler,
        v_parameterization=args.v_parameterization,
    )
    schedulers[args.sample_sampler] = default_scheduler

    pipeline = pipe_class(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=default_scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
        clip_skip=args.clip_skip,
    )
    pipeline.to(device)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # Remember the RNG state so that per-prompt seeding does not affect the caller.
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    with torch.no_grad():
        for i, prompt_dict in enumerate(prompts):
            if not accelerator.is_main_process:
                continue

            if isinstance(prompt_dict, str):
                prompt_dict = line_to_prompt_dict(prompt_dict)

            assert isinstance(prompt_dict, dict)
            negative_prompt = prompt_dict.get("negative_prompt")
            sample_steps = prompt_dict.get("sample_steps", 30)
            width = prompt_dict.get("width", 512)
            height = prompt_dict.get("height", 512)
            scale = prompt_dict.get("scale", 7.5)
            seed = prompt_dict.get("seed")
            controlnet_image = prompt_dict.get("controlnet_image")
            prompt: str = prompt_dict.get("prompt", "")
            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

            if seed is not None:
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)

            scheduler = schedulers.get(sampler_name)
            if scheduler is None:
                scheduler = get_my_scheduler(
                    sample_sampler=sampler_name,
                    v_parameterization=args.v_parameterization,
                )
                schedulers[sampler_name] = scheduler
            pipeline.scheduler = scheduler

            if prompt_replacement is not None:
                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                if negative_prompt is not None:
                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

            if controlnet_image is not None:
                controlnet_image = Image.open(controlnet_image).convert("RGB")
                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

            # Round down to a multiple of 8, with a floor of 64.
            height = max(64, height - height % 8)
            width = max(64, width - width % 8)
            print(f"prompt: {prompt}")
            print(f"negative_prompt: {negative_prompt}")
            print(f"height: {height}")
            print(f"width: {width}")
            print(f"sample_steps: {sample_steps}")
            print(f"scale: {scale}")
            print(f"sample_sampler: {sampler_name}")
            if seed is not None:
                print(f"seed: {seed}")
            with accelerator.autocast():
                latents = pipeline(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=sample_steps,
                    guidance_scale=scale,
                    negative_prompt=negative_prompt,
                    controlnet=controlnet,
                    controlnet_image=controlnet_image,
                )

            image = pipeline.latents_to_image(latents)[0]

            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
            seed_suffix = "" if seed is None else f"_{seed}"
            img_filename = (
                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
            )

            image.save(os.path.join(save_dir, img_filename))

            # Best-effort logging to wandb: any failure here (tracker not configured,
            # wandb not installed, logging error) is ignored so sampling continues.
            # Only Exception is caught, so Ctrl-C / SystemExit still propagate.
            try:
                wandb_tracker = accelerator.get_tracker("wandb")
                try:
                    import wandb
                except ImportError:
                    raise ImportError("No wandb / wandb がインストールされていないようです")

                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
            except Exception:
                pass

    # Drop the pipeline and cached memory to reduce VRAM usage.
    del pipeline
    torch.cuda.empty_cache()

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)
    vae.to(org_vae_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageLoadingDataset(torch.utils.data.Dataset):
    """Dataset over a list of image paths that yields ``(image_tensor, path)`` pairs."""

    def __init__(self, image_paths):
        # Kept as given; indexed directly by __getitem__.
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and convert it to a tensor.

        Returns ``None`` (after printing the error) when the image cannot be
        loaded or converted.
        """
        path = self.images[idx]

        try:
            rgb_image = Image.open(path).convert("RGB")
            image_tensor = transforms.functional.pil_to_tensor(rgb_image)
        except Exception as e:
            print(f"Could not load image path / 画像を読み込めません: {path}, error: {e}")
            return None
        else:
            return (image_tensor, path)
|
|
|
|
|
|
|
|
|
|
|
|
|
class collator_class:
    """Collate function object that pushes the current epoch/step into the dataset.

    ``epoch`` and ``step`` are objects exposing a ``.value`` attribute; their
    values are read on every call.
    """

    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch
        self.current_step = step
        # Used when the collator runs outside a DataLoader worker.
        self.dataset = dataset

    def __call__(self, examples):
        """Update the dataset's epoch/step and return the first element of ``examples``."""
        worker_info = torch.utils.data.get_worker_info()
        # Inside a worker process, use the worker's own dataset object.
        target = self.dataset if worker_info is None else worker_info.dataset

        target.set_current_epoch(self.current_epoch.value)
        target.set_current_step(self.current_step.value)
        return examples[0]
|
|
|
|
|
class LossRecorder:
    """Keeps the most recent loss of every step position and their running total.

    During epoch 0 each loss is appended; in later epochs the loss recorded at
    the same step position is replaced, so ``moving_average`` is the mean over
    the latest value of every position.
    """

    def __init__(self):
        # Latest loss per step position.
        self.loss_list: List[float] = []
        # Sum of the values currently held in loss_list.
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for ``step`` of ``epoch``.

        In epochs after the first, a ``step`` beyond the recorded length (for
        example when an epoch has more steps than epoch 0 had) extends the list
        with zeros instead of raising ``IndexError``.
        """
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # Pad so that loss_list[step] exists; padding contributes 0 to the total.
            missing = step + 1 - len(self.loss_list)
            if missing > 0:
                self.loss_list.extend([0.0] * missing)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Mean of the recorded losses, or 0.0 when nothing has been recorded yet."""
        if not self.loss_list:
            return 0.0
        return self.loss_total / len(self.loss_list)
|
|